diff --git a/docs/examples/Predict_Contributions.ipynb b/docs/examples/Predict_Contributions.ipynb
new file mode 100644
index 0000000..57ecfec
--- /dev/null
+++ b/docs/examples/Predict_Contributions.ipynb
@@ -0,0 +1,1158 @@
+{
+ "cells": [
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Show how to extract prediction contributions for each distribution parameter\n",
+ "\n",
 + "This example shows how to get the contribution of every feature to each distributional parameter for a given data set. This allows similar inferences to those one might get from SHAP, but comes from LightGBM's internal workings. For example, we can use the output to determine, for a given prediction, which features have the most impact on a given distributional parameter.\n",
+ "\n",
 + "These contributions are created before the response function is applied. As such, in the case of the identity function, the sum of the contributions for a given row of data should equal the parameter value.\n"
+ ],
+ "id": "bf95ab4267d5a34"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Imports\n",
+ "\n",
+ "First, we import necessary functions. "
+ ],
+ "id": "bbea43740b87eb"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:08:32.499517Z",
+ "start_time": "2024-09-26T10:08:32.484942Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "from lightgbmlss.model import *\n",
+ "from lightgbmlss.distributions.Gaussian import *\n",
+ "from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data\n",
+ "from scipy.stats import norm\n",
+ "\n",
+ "import plotnine\n",
+ "from plotnine import *\n",
+ "\n",
+ "plotnine.options.figure_size = (12, 8)"
+ ],
+ "id": "b5f2d07ce70bb24b",
+ "outputs": [],
+ "execution_count": 45
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Data",
+ "id": "bd7bba77a5e0fa2f"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:08:32.595603Z",
+ "start_time": "2024-09-26T10:08:32.563920Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "train, test = load_simulated_gaussian_data()\n",
+ "\n",
+ "X_train, y_train = train.filter(regex=\"x\"), train[\"y\"].values\n",
+ "X_test, y_test = test.filter(regex=\"x\"), test[\"y\"].values\n",
+ "\n",
+ "dtrain = lgb.Dataset(X_train, label=y_train)"
+ ],
+ "id": "1062b4b851a12bc9",
+ "outputs": [],
+ "execution_count": 46
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Get a Trained Model\n",
+ "\n",
 + "As this example is about the uses of a trained model, we won't do any hyper-parameter searching. We will also use a Gaussian distribution: because the response function of the loc parameter is the identity function, this will allow us to more easily compare the outputs of a standard parameter prediction to a contributions prediction."
+ ],
+ "id": "170feafe1dccf85c"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.586326Z",
+ "start_time": "2024-09-26T10:08:32.595603Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "lgblss = LightGBMLSS(\n",
+ " Gaussian()\n",
+ ")\n",
+ "lgblss.train(\n",
+ " params=dict(),\n",
+ " train_set=dtrain\n",
+ ")\n",
+ "\n",
+ "param_dict = {\n",
+ " \"eta\": [\"float\", {\"low\": 1e-5, \"high\": 1, \"log\": True}],\n",
+ " \"max_depth\": [\"int\", {\"low\": 1, \"high\": 10, \"log\": False}],\n",
+ " \"num_leaves\": [\"int\", {\"low\": 255, \"high\": 255, \"log\": False}], # set to constant for this example\n",
+ " \"min_data_in_leaf\": [\"int\", {\"low\": 20, \"high\": 20, \"log\": False}], # set to constant for this example\n",
+ " \"min_gain_to_split\": [\"float\", {\"low\": 1e-8, \"high\": 40, \"log\": False}],\n",
+ " \"min_sum_hessian_in_leaf\": [\"float\", {\"low\": 1e-8, \"high\": 500, \"log\": True}],\n",
+ " \"subsample\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n",
+ " \"feature_fraction\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n",
+ " \"boosting\": [\"categorical\", [\"gbdt\"]],\n",
+ "}\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "opt_param = lgblss.hyper_opt(param_dict,\n",
+ " dtrain,\n",
+ " num_boost_round=100, # Number of boosting iterations.\n",
+ " nfold=5, # Number of cv-folds.\n",
+ " early_stopping_rounds=20, # Number of early-stopping rounds\n",
+ " max_minutes=10, # Time budget in minutes, i.e., stop study after the given number of minutes.\n",
+ " n_trials=30 , # The number of trials. If this argument is set to None, there is no limitation on the number of trials.\n",
 + "                            silence=True,             # Controls the verbosity of the trial, i.e., user can silence the outputs of the trial.\n",
+ " seed=123, # Seed used to generate cv-folds.\n",
+ " hp_seed=123 # Seed for random number generator used in the Bayesian hyperparameter search.\n",
+ " )\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "\n",
+ "opt_params = opt_param.copy()\n",
+ "n_rounds = opt_params[\"opt_rounds\"]\n",
+ "del opt_params[\"opt_rounds\"]\n",
+ "\n",
+ "# Train Model with optimized hyperparameters\n",
+ "lgblss.train(opt_params,\n",
+ " dtrain,\n",
+ " num_boost_round=n_rounds\n",
+ " )\n"
+ ],
+ "id": "f45c868160f1f08b",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " 0%| | 0/30 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "45def0cbae7345c2af90d41ce5c331b0"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Hyper-Parameter Optimization successfully finished.\n",
+ " Number of finished trials: 30\n",
+ " Best trial:\n",
+ " Value: 2.0839106241730967\n",
+ " Params: \n",
+ " eta: 0.042322345196562056\n",
+ " max_depth: 3\n",
+ " num_leaves: 255\n",
+ " min_data_in_leaf: 20\n",
+ " min_gain_to_split: 10.495083287505906\n",
+ " min_sum_hessian_in_leaf: 4.025662198099785e-06\n",
+ " subsample: 0.41879883505881144\n",
+ " feature_fraction: 0.7628021535153005\n",
+ " boosting: gbdt\n",
+ " opt_rounds: 72\n"
+ ]
+ }
+ ],
+ "execution_count": 47
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Get parameter predictions and parameter contribution predictions",
+ "id": "3c8358d79ec85438"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.618477Z",
+ "start_time": "2024-09-26T10:09:16.586326Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "pred_params = lgblss.predict(X_test, pred_type=\"parameters\")\n",
+ "pred_param_contributions = lgblss.predict(X_test, pred_type=\"contributions\")\n"
+ ],
+ "id": "c0bab6ad5807cd8d",
+ "outputs": [],
+ "execution_count": 48
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
 + "source": "As the location parameter uses the identity function, the sum of these contributions should equal the value in pred_params. However, this is not true for the scale parameter, as response functions have not been applied when contributions are created.",
+ "id": "c39c97f68b1ba929"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.639879Z",
+ "start_time": "2024-09-26T10:09:16.618477Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "sum_of_contributions = pred_param_contributions.groupby(level=\"distribution_arg\", axis=1).sum()\n",
+ "location_values_are_all_close = np.allclose(pred_params[\"loc\"], sum_of_contributions[\"loc\"])\n",
+ "scale_values_are_all_close = np.allclose(pred_params[\"scale\"], sum_of_contributions[\"scale\"])\n",
+ "\n",
+ "\n",
+ "print(f\"{location_values_are_all_close=}\")\n",
+ "print(f\"{scale_values_are_all_close=}\")\n"
+ ],
+ "id": "87e216c88a4ff947",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "location_values_are_all_close=True\n",
+ "scale_values_are_all_close=False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\SimonRobertPike\\AppData\\Local\\Temp\\ipykernel_47316\\1135199838.py:1: FutureWarning: DataFrame.groupby with axis=1 is deprecated. Do `frame.T.groupby(...)` without axis instead.\n"
+ ]
+ }
+ ],
+ "execution_count": 49
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "### Show contributions for each feature for location parameter",
+ "id": "90e4b7d2544afd58"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.672598Z",
+ "start_time": "2024-09-26T10:09:16.642194Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "pred_param_contributions.xs(\"loc\", axis=1, level=\"distribution_arg\")",
+ "id": "1b6b7013a1f7e957",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 x_noise5 \\\n",
+ "0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "1 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "3 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "4 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "... ... ... ... ... ... ... \n",
+ "2995 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2996 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2997 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2998 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2999 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "\n",
+ "FeatureContribution x_noise6 x_noise7 x_noise8 x_noise9 x_noise10 \\\n",
+ "0 0.0 0.0 0.0 0.0 0.0 \n",
+ "1 0.0 0.0 0.0 0.0 0.0 \n",
+ "2 0.0 0.0 0.0 0.0 0.0 \n",
+ "3 0.0 0.0 0.0 0.0 0.0 \n",
+ "4 0.0 0.0 0.0 0.0 0.0 \n",
+ "... ... ... ... ... ... \n",
+ "2995 0.0 0.0 0.0 0.0 0.0 \n",
+ "2996 0.0 0.0 0.0 0.0 0.0 \n",
+ "2997 0.0 0.0 0.0 0.0 0.0 \n",
+ "2998 0.0 0.0 0.0 0.0 0.0 \n",
+ "2999 0.0 0.0 0.0 0.0 0.0 \n",
+ "\n",
+ "FeatureContribution Const \n",
+ "0 9.979578 \n",
+ "1 9.979578 \n",
+ "2 9.979578 \n",
+ "3 9.979578 \n",
+ "4 9.979578 \n",
+ "... ... \n",
+ "2995 9.979578 \n",
+ "2996 9.979578 \n",
+ "2997 9.979578 \n",
+ "2998 9.979578 \n",
+ "2999 9.979578 \n",
+ "\n",
+ "[3000 rows x 12 columns]"
+ ],
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | FeatureContribution | \n",
+ " x_true | \n",
+ " x_noise1 | \n",
+ " x_noise2 | \n",
+ " x_noise3 | \n",
+ " x_noise4 | \n",
+ " x_noise5 | \n",
+ " x_noise6 | \n",
+ " x_noise7 | \n",
+ " x_noise8 | \n",
+ " x_noise9 | \n",
+ " x_noise10 | \n",
+ " Const | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " | 2995 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 2996 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 2997 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 2998 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ " | 2999 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979578 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
3000 rows × 12 columns
\n",
+ "
"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 50
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "### Show contributions for each feature for scale parameter",
+ "id": "eaf2ad3ecc736152"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.706240Z",
+ "start_time": "2024-09-26T10:09:16.673598Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\")\n",
+ "id": "c5453f7e5a378096",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 \\\n",
+ "0 0.410556 0.002107 0.0 0.0 0.000034 \n",
+ "1 0.411267 0.000683 0.0 0.0 -0.000340 \n",
+ "2 -0.597710 0.002107 0.0 0.0 -0.000340 \n",
+ "3 0.848812 0.002835 0.0 0.0 0.000034 \n",
+ "4 0.414533 0.001566 0.0 0.0 0.000867 \n",
+ "... ... ... ... ... ... \n",
+ "2995 0.411235 0.002835 0.0 0.0 -0.000340 \n",
+ "2996 0.380668 0.002107 0.0 0.0 -0.000340 \n",
+ "2997 -0.597620 0.001648 0.0 0.0 0.000034 \n",
+ "2998 -0.607374 -0.001427 0.0 0.0 -0.001144 \n",
+ "2999 0.410556 0.002107 0.0 0.0 0.000034 \n",
+ "\n",
+ "FeatureContribution x_noise5 x_noise6 x_noise7 x_noise8 x_noise9 \\\n",
+ "0 0.000197 0.004104 -0.000126 0.0 -0.000608 \n",
+ "1 0.000197 0.004816 -0.000126 0.0 -0.000608 \n",
+ "2 0.000197 0.004104 -0.000126 0.0 -0.000608 \n",
+ "3 0.000197 0.001399 -0.000126 0.0 -0.000608 \n",
+ "4 0.000123 0.002717 -0.004173 0.0 0.053938 \n",
+ "... ... ... ... ... ... \n",
+ "2995 0.000197 0.002135 -0.000126 0.0 -0.000608 \n",
+ "2996 0.000197 0.004402 -0.000126 0.0 -0.000608 \n",
+ "2997 0.000197 -0.004548 -0.000126 0.0 -0.000700 \n",
+ "2998 0.000888 0.002017 0.003200 0.0 -0.000029 \n",
+ "2999 0.000197 0.004104 -0.000126 0.0 -0.000608 \n",
+ "\n",
+ "FeatureContribution x_noise10 Const \n",
+ "0 -0.000503 0.653625 \n",
+ "1 -0.000130 0.653625 \n",
+ "2 -0.000130 0.653625 \n",
+ "3 0.001530 0.653625 \n",
+ "4 0.001895 0.653625 \n",
+ "... ... ... \n",
+ "2995 0.000432 0.653625 \n",
+ "2996 0.002378 0.653625 \n",
+ "2997 -0.000892 0.653625 \n",
+ "2998 -0.004399 0.653625 \n",
+ "2999 -0.000503 0.653625 \n",
+ "\n",
+ "[3000 rows x 12 columns]"
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | FeatureContribution | \n",
+ " x_true | \n",
+ " x_noise1 | \n",
+ " x_noise2 | \n",
+ " x_noise3 | \n",
+ " x_noise4 | \n",
+ " x_noise5 | \n",
+ " x_noise6 | \n",
+ " x_noise7 | \n",
+ " x_noise8 | \n",
+ " x_noise9 | \n",
+ " x_noise10 | \n",
+ " Const | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0.410556 | \n",
+ " 0.002107 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " 0.004104 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000503 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0.411267 | \n",
+ " 0.000683 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.004816 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000130 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " -0.597710 | \n",
+ " 0.002107 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.004104 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000130 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0.848812 | \n",
+ " 0.002835 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " 0.001399 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " 0.001530 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0.414533 | \n",
+ " 0.001566 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000867 | \n",
+ " 0.000123 | \n",
+ " 0.002717 | \n",
+ " -0.004173 | \n",
+ " 0.0 | \n",
+ " 0.053938 | \n",
+ " 0.001895 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " | 2995 | \n",
+ " 0.411235 | \n",
+ " 0.002835 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.002135 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " 0.000432 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 2996 | \n",
+ " 0.380668 | \n",
+ " 0.002107 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.004402 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " 0.002378 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 2997 | \n",
+ " -0.597620 | \n",
+ " 0.001648 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " -0.004548 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000700 | \n",
+ " -0.000892 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 2998 | \n",
+ " -0.607374 | \n",
+ " -0.001427 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.001144 | \n",
+ " 0.000888 | \n",
+ " 0.002017 | \n",
+ " 0.003200 | \n",
+ " 0.0 | \n",
+ " -0.000029 | \n",
+ " -0.004399 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | 2999 | \n",
+ " 0.410556 | \n",
+ " 0.002107 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " 0.004104 | \n",
+ " -0.000126 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000503 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
3000 rows × 12 columns
\n",
+ "
"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 51
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Show Mean Feature Impact for Data set",
+ "id": "394e64d247168fa0"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.722325Z",
+ "start_time": "2024-09-26T10:09:16.706781Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "sum_of_contributions_column = \"SumOfContributions\"\n",
+ "mean_parameter_contribution = pred_param_contributions.abs().mean().unstack(\"distribution_arg\")\n",
+ "mean_parameter_contribution[sum_of_contributions_column] = mean_parameter_contribution.sum(1)\n",
+ "\n",
+ "mean_parameter_contribution.sort_values(sum_of_contributions_column, ascending=False).drop(columns=sum_of_contributions_column)\n"
+ ],
+ "id": "54d4970cf1957735",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "distribution_arg loc scale\n",
+ "FeatureContribution \n",
+ "Const 9.979577 0.653625\n",
+ "x_true 0.000000 0.591884\n",
+ "x_noise6 0.000000 0.004868\n",
+ "x_noise7 0.000000 0.004415\n",
+ "x_noise1 0.000000 0.003994\n",
+ "x_noise10 0.000000 0.002689\n",
+ "x_noise9 0.000000 0.002583\n",
+ "x_noise4 0.000000 0.001668\n",
+ "x_noise5 0.000000 0.000585\n",
+ "x_noise2 0.000000 0.000000\n",
+ "x_noise3 0.000000 0.000000\n",
+ "x_noise8 0.000000 0.000000"
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | distribution_arg | \n",
+ " loc | \n",
+ " scale | \n",
+ "
\n",
+ " \n",
+ " | FeatureContribution | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | Const | \n",
+ " 9.979577 | \n",
+ " 0.653625 | \n",
+ "
\n",
+ " \n",
+ " | x_true | \n",
+ " 0.000000 | \n",
+ " 0.591884 | \n",
+ "
\n",
+ " \n",
+ " | x_noise6 | \n",
+ " 0.000000 | \n",
+ " 0.004868 | \n",
+ "
\n",
+ " \n",
+ " | x_noise7 | \n",
+ " 0.000000 | \n",
+ " 0.004415 | \n",
+ "
\n",
+ " \n",
+ " | x_noise1 | \n",
+ " 0.000000 | \n",
+ " 0.003994 | \n",
+ "
\n",
+ " \n",
+ " | x_noise10 | \n",
+ " 0.000000 | \n",
+ " 0.002689 | \n",
+ "
\n",
+ " \n",
+ " | x_noise9 | \n",
+ " 0.000000 | \n",
+ " 0.002583 | \n",
+ "
\n",
+ " \n",
+ " | x_noise4 | \n",
+ " 0.000000 | \n",
+ " 0.001668 | \n",
+ "
\n",
+ " \n",
+ " | x_noise5 | \n",
+ " 0.000000 | \n",
+ " 0.000585 | \n",
+ "
\n",
+ " \n",
+ " | x_noise2 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " | x_noise3 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " | x_noise8 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 52
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Get correlation between contributions for the scale parameter ",
+ "id": "f7c73f303f04d4ff"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "",
+ "id": "8d5dc9e448d5c322"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.737953Z",
+ "start_time": "2024-09-26T10:09:16.722325Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\").corr().dropna(how=\"all\").dropna(axis=1,how=\"all\")\n",
+ "id": "f331d8603042908",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "FeatureContribution x_true x_noise1 x_noise4 x_noise5 x_noise6 \\\n",
+ "FeatureContribution \n",
+ "x_true 1.000000 0.007743 -0.001231 -0.047812 -0.021563 \n",
+ "x_noise1 0.007743 1.000000 -0.006635 -0.022209 0.136772 \n",
+ "x_noise4 -0.001231 -0.006635 1.000000 -0.015972 -0.030669 \n",
+ "x_noise5 -0.047812 -0.022209 -0.015972 1.000000 0.006214 \n",
+ "x_noise6 -0.021563 0.136772 -0.030669 0.006214 1.000000 \n",
+ "x_noise7 0.015347 0.002129 0.474525 0.021844 0.029845 \n",
+ "x_noise9 0.024365 -0.006972 0.013082 0.015998 0.009551 \n",
+ "x_noise10 0.035477 0.012110 -0.035711 -0.001439 0.028450 \n",
+ "\n",
+ "FeatureContribution x_noise7 x_noise9 x_noise10 \n",
+ "FeatureContribution \n",
+ "x_true 0.015347 0.024365 0.035477 \n",
+ "x_noise1 0.002129 -0.006972 0.012110 \n",
+ "x_noise4 0.474525 0.013082 -0.035711 \n",
+ "x_noise5 0.021844 0.015998 -0.001439 \n",
+ "x_noise6 0.029845 0.009551 0.028450 \n",
+ "x_noise7 1.000000 0.023553 -0.015334 \n",
+ "x_noise9 0.023553 1.000000 -0.030410 \n",
+ "x_noise10 -0.015334 -0.030410 1.000000 "
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | FeatureContribution | \n",
+ " x_true | \n",
+ " x_noise1 | \n",
+ " x_noise4 | \n",
+ " x_noise5 | \n",
+ " x_noise6 | \n",
+ " x_noise7 | \n",
+ " x_noise9 | \n",
+ " x_noise10 | \n",
+ "
\n",
+ " \n",
+ " | FeatureContribution | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | x_true | \n",
+ " 1.000000 | \n",
+ " 0.007743 | \n",
+ " -0.001231 | \n",
+ " -0.047812 | \n",
+ " -0.021563 | \n",
+ " 0.015347 | \n",
+ " 0.024365 | \n",
+ " 0.035477 | \n",
+ "
\n",
+ " \n",
+ " | x_noise1 | \n",
+ " 0.007743 | \n",
+ " 1.000000 | \n",
+ " -0.006635 | \n",
+ " -0.022209 | \n",
+ " 0.136772 | \n",
+ " 0.002129 | \n",
+ " -0.006972 | \n",
+ " 0.012110 | \n",
+ "
\n",
+ " \n",
+ " | x_noise4 | \n",
+ " -0.001231 | \n",
+ " -0.006635 | \n",
+ " 1.000000 | \n",
+ " -0.015972 | \n",
+ " -0.030669 | \n",
+ " 0.474525 | \n",
+ " 0.013082 | \n",
+ " -0.035711 | \n",
+ "
\n",
+ " \n",
+ " | x_noise5 | \n",
+ " -0.047812 | \n",
+ " -0.022209 | \n",
+ " -0.015972 | \n",
+ " 1.000000 | \n",
+ " 0.006214 | \n",
+ " 0.021844 | \n",
+ " 0.015998 | \n",
+ " -0.001439 | \n",
+ "
\n",
+ " \n",
+ " | x_noise6 | \n",
+ " -0.021563 | \n",
+ " 0.136772 | \n",
+ " -0.030669 | \n",
+ " 0.006214 | \n",
+ " 1.000000 | \n",
+ " 0.029845 | \n",
+ " 0.009551 | \n",
+ " 0.028450 | \n",
+ "
\n",
+ " \n",
+ " | x_noise7 | \n",
+ " 0.015347 | \n",
+ " 0.002129 | \n",
+ " 0.474525 | \n",
+ " 0.021844 | \n",
+ " 0.029845 | \n",
+ " 1.000000 | \n",
+ " 0.023553 | \n",
+ " -0.015334 | \n",
+ "
\n",
+ " \n",
+ " | x_noise9 | \n",
+ " 0.024365 | \n",
+ " -0.006972 | \n",
+ " 0.013082 | \n",
+ " 0.015998 | \n",
+ " 0.009551 | \n",
+ " 0.023553 | \n",
+ " 1.000000 | \n",
+ " -0.030410 | \n",
+ "
\n",
+ " \n",
+ " | x_noise10 | \n",
+ " 0.035477 | \n",
+ " 0.012110 | \n",
+ " -0.035711 | \n",
+ " -0.001439 | \n",
+ " 0.028450 | \n",
+ " -0.015334 | \n",
+ " -0.030410 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 53
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:09:16.753632Z",
+ "start_time": "2024-09-26T10:09:16.738083Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "",
+ "id": "ae0a0247ad688b42",
+ "outputs": [],
+ "execution_count": 53
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py
index a8e90ee..62ca5ae 100644
--- a/lightgbmlss/distributions/distribution_utils.py
+++ b/lightgbmlss/distributions/distribution_utils.py
@@ -14,6 +14,7 @@
import warnings
+
class DistributionClass:
"""
Generic class that contains general functions for univariate distributions.
@@ -43,6 +44,7 @@ class DistributionClass:
penalize_crossing: bool
Whether to include a penalty term to discourage crossing of expectiles. Only used for Expectile distribution.
"""
+
def __init__(self,
distribution: torch.distributions.Distribution = None,
univariate: bool = True,
@@ -375,51 +377,96 @@ def predict_dist(self,
Predictions.
"""
+ kwargs = dict()
+ if pred_type == "contributions":
+ kwargs["pred_contrib"] = True
+ n_outputs_per_dist = data.shape[1] + 1
+ else:
+ n_outputs_per_dist = 1
+
predt = torch.tensor(
- booster.predict(data, raw_score=True),
+ booster.predict(data, raw_score=True, **kwargs),
dtype=torch.float32
- ).reshape(-1, self.n_dist_param)
+ ).reshape(-1, self.n_dist_param * n_outputs_per_dist)
# Set init_score as starting point for each distributional parameter.
init_score_pred = torch.tensor(
- np.ones(shape=(data.shape[0], 1))*start_values,
+ np.ones(shape=(data.shape[0], 1)) * start_values,
dtype=torch.float32
)
- # The predictions don't include the init_score specified in creating the train data.
- # Hence, it needs to be added manually with the corresponding transform for each distributional parameter.
- dist_params_predt = np.concatenate(
- [
- response_fun(
- predt[:, i].reshape(-1, 1) + init_score_pred[:, i].reshape(-1, 1)).numpy()
- for i, (dist_param, response_fun) in enumerate(self.param_dict.items())
- ],
- axis=1,
- )
- dist_params_predt = pd.DataFrame(dist_params_predt)
- dist_params_predt.columns = self.param_dict.keys()
-
- # Draw samples from predicted response distribution
- pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
- n_samples=n_samples,
- seed=seed)
+ if pred_type == "contributions":
+ CONST_COL = "Const"
+ COLUMN_LEVELS = ["parameters", "feature_contributions"]
+
+ feature_columns = data.columns.tolist() + [CONST_COL]
+ contributions_predt = pd.DataFrame(
+ predt,
+ columns=pd.MultiIndex.from_product(
+ [self.distribution_arg_names, feature_columns],
+ names=COLUMN_LEVELS
+ ),
+ index=data.index,
+ )
- if pred_type == "parameters":
- return dist_params_predt
+ init_score_pred_df = pd.DataFrame(
+ init_score_pred,
+ columns=pd.MultiIndex.from_product(
+ [self.distribution_arg_names, ["Const"]],
+ names=COLUMN_LEVELS
+ ),
+ index=data.index
+ )
+ contributions_predt[init_score_pred_df.columns] = (
+ contributions_predt[init_score_pred_df.columns] + init_score_pred_df
+ )
 +            # Can't apply the response function to individual feature contributions
+ return contributions_predt
+ else:
+ # The predictions don't include the init_score specified in creating the train data.
+ # Hence, it needs to be added manually with the corresponding transform for each distributional parameter.
+ dist_params_predt = np.concatenate(
+ [
+ response_fun(
+ predt[:, i].reshape(-1, 1) + init_score_pred[:, i].reshape(-1, 1)).numpy()
+ for i, (dist_param, response_fun) in enumerate(self.param_dict.items())
+ ],
+ axis=1,
+ )
+ dist_params_predt = pd.DataFrame(
+ index=data.index,
+ data=dist_params_predt,
+ columns=pd.Index(
+ self.param_dict.keys(),
+ name=pred_type if pred_type == "expectiles" else "parameters"
+ )
+ )
- elif pred_type == "expectiles":
- return dist_params_predt
+ if pred_type == "parameters":
+ return dist_params_predt
- elif pred_type == "samples":
- return pred_samples_df
+ elif pred_type == "expectiles":
+ return dist_params_predt
+ else:
- elif pred_type == "quantiles":
- # Calculate quantiles from predicted response distribution
- pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T
- pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))]
- if self.discrete:
- pred_quant_df = pred_quant_df.astype(int)
- return pred_quant_df
+ # Draw samples from predicted response distribution
+ pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
+ n_samples=n_samples,
+ seed=seed)
+ pred_samples_df.columns.name = "samples"
+ if pred_type == "samples":
+ return pred_samples_df
+
+ elif pred_type == "quantiles":
+ # Calculate quantiles from predicted response distribution
+ pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T
+ pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))]
+ if self.discrete:
+ pred_quant_df = pred_quant_df.astype(int)
+ pred_quant_df.columns.name = "quantiles"
+ return pred_quant_df
+ else:
+ raise RuntimeError(f"{pred_type=} not supported")
def compute_gradients_and_hessians(self,
loss: torch.tensor,
@@ -635,7 +682,7 @@ def dist_select(self,
try:
loss, params = dist_sel.calculate_start_values(target=target.reshape(-1, 1), max_iter=max_iter)
fit_df = pd.DataFrame.from_dict(
- {self.loss_fn: loss.reshape(-1,),
+ {self.loss_fn: loss.reshape(-1, ),
"distribution": str(dist_name),
"params": [params]
}
diff --git a/lightgbmlss/model.py b/lightgbmlss/model.py
index 33896ce..9a006bc 100644
--- a/lightgbmlss/model.py
+++ b/lightgbmlss/model.py
@@ -452,6 +452,7 @@ def predict(self,
- "quantiles" calculates the quantiles from the predicted distribution.
- "parameters" returns the predicted distributional parameters.
- "expectiles" returns the predicted expectiles.
 +            - "contributions" returns the contributions of each feature and a constant by calling booster.predict(pred_contrib=True)
n_samples : int
Number of samples to draw from the predicted distribution.
quantiles : List[float]
diff --git a/tests/test_model/test_model.py b/tests/test_model/test_model.py
index c1f12a1..60cd6a6 100644
--- a/tests/test_model/test_model.py
+++ b/tests/test_model/test_model.py
@@ -1,3 +1,6 @@
+import numpy as np
+import pandas as pd
+
from lightgbmlss.model import *
from lightgbmlss.distributions.Gaussian import *
from lightgbmlss.distributions.Mixture import *
@@ -6,6 +9,7 @@
from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data
import pytest
from pytest import approx
+from lightgbmlss.utils import identity_fn
@pytest.fixture
@@ -109,7 +113,7 @@ def test_model_univ_train_eval(self, univariate_data, univariate_lgblss, univari
# Assertions
assert isinstance(lgblss.booster, lgb.Booster)
- def test_model_hpo(self, univariate_data, univariate_lgblss,):
+ def test_model_hpo(self, univariate_data, univariate_lgblss, ):
# Unpack
dtrain, _, _, _ = univariate_data
lgblss = univariate_lgblss
@@ -155,6 +159,7 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para
pred_params = lgblss.predict(X_test, pred_type="parameters")
pred_samples = lgblss.predict(X_test, pred_type="samples", n_samples=n_samples)
pred_quantiles = lgblss.predict(X_test, pred_type="quantiles", quantiles=quantiles)
+ pred_contributions = lgblss.predict(X_test, pred_type="contributions")
# Assertions
assert isinstance(pred_params, (pd.DataFrame, type(None)))
@@ -162,16 +167,41 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para
assert not np.isinf(pred_params).any().any()
assert pred_params.shape[1] == lgblss.dist.n_dist_param
assert approx(pred_params["loc"].mean(), abs=0.2) == 10.0
+ assert pred_params.columns.name == "parameters"
assert isinstance(pred_samples, (pd.DataFrame, type(None)))
assert not pred_samples.isna().any().any()
assert not np.isinf(pred_samples).any().any()
assert pred_samples.shape[1] == n_samples
+ assert pred_samples.columns.name == "samples"
assert isinstance(pred_quantiles, (pd.DataFrame, type(None)))
assert not pred_quantiles.isna().any().any()
assert not np.isinf(pred_quantiles).any().any()
assert pred_quantiles.shape[1] == len(quantiles)
+ assert pred_quantiles.columns.name == "quantiles"
+
+ assert isinstance(pred_contributions, (pd.DataFrame, type(None)))
+ assert not pred_contributions.isna().any().any()
+ assert not np.isinf(pred_contributions).any().any()
+ assert (pred_contributions.shape[1] ==
+ lgblss.dist.n_dist_param * (X_test.shape[1] + 1)
+ )
+
+ assert pred_contributions.columns.names == ["parameters", "feature_contributions"]
+
+ for key, response_func in lgblss.dist.param_dict.items():
+ # Sum contributions for each parameter and apply response function
+ pred_contributions_combined = (
+ pd.Series(response_func(
+ torch.tensor(
+ pred_contributions.xs(key, level="parameters", axis=1).sum(axis=1).values)
+ )))
+ assert np.allclose(
+ pred_contributions_combined,
+ pred_params[key], atol=1e-5
+ )
+
def test_model_plot(self, univariate_data, univariate_lgblss, univariate_params):
# Unpack