diff --git a/docs/examples/Predict_Contributions.ipynb b/docs/examples/Predict_Contributions.ipynb new file mode 100644 index 0000000..57ecfec --- /dev/null +++ b/docs/examples/Predict_Contributions.ipynb @@ -0,0 +1,1158 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Show how to extract prediction contributions for each distribution parameter\n", + "\n", + "This example shows how to get the contribution of every feature to each distributional parameter for a given data set. This allows inferences similar to those one might get from SHAP, but the values come from LightGBM's internal workings. For example, for a given prediction we can use the output to see which features have the most impact on a given distributional parameter.\n", + "\n", + "These contributions are computed before the response function is applied. As such, when the response function is the identity function, the sum of the contributions for a given row of data should equal the parameter value.\n" + ], + "id": "bf95ab4267d5a34" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Imports\n", + "\n", + "First, we import the necessary functions. 
" + ], + "id": "bbea43740b87eb" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:08:32.499517Z", + "start_time": "2024-09-26T10:08:32.484942Z" + } + }, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "from lightgbmlss.model import *\n", + "from lightgbmlss.distributions.Gaussian import *\n", + "from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data\n", + "from scipy.stats import norm\n", + "\n", + "import plotnine\n", + "from plotnine import *\n", + "\n", + "plotnine.options.figure_size = (12, 8)" + ], + "id": "b5f2d07ce70bb24b", + "outputs": [], + "execution_count": 45 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Data", + "id": "bd7bba77a5e0fa2f" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:08:32.595603Z", + "start_time": "2024-09-26T10:08:32.563920Z" + } + }, + "cell_type": "code", + "source": [ + "train, test = load_simulated_gaussian_data()\n", + "\n", + "X_train, y_train = train.filter(regex=\"x\"), train[\"y\"].values\n", + "X_test, y_test = test.filter(regex=\"x\"), test[\"y\"].values\n", + "\n", + "dtrain = lgb.Dataset(X_train, label=y_train)" + ], + "id": "1062b4b851a12bc9", + "outputs": [], + "execution_count": 46 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Get a Trained Model\n", + "\n", + "As this example is about the uses of a trained model, we won't do any hyper-parameter searching. We will also use a Gaussian distribution because the response function of its loc parameter is the identity function; this allows us to more easily compare the outputs of a standard parameter prediction to a contributions prediction." 
+ ], + "id": "170feafe1dccf85c" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:09:16.586326Z", + "start_time": "2024-09-26T10:08:32.595603Z" + } + }, + "cell_type": "code", + "source": [ + "lgblss = LightGBMLSS(\n", + " Gaussian()\n", + ")\n", + "lgblss.train(\n", + " params=dict(),\n", + " train_set=dtrain\n", + ")\n", + "\n", + "param_dict = {\n", + " \"eta\": [\"float\", {\"low\": 1e-5, \"high\": 1, \"log\": True}],\n", + " \"max_depth\": [\"int\", {\"low\": 1, \"high\": 10, \"log\": False}],\n", + " \"num_leaves\": [\"int\", {\"low\": 255, \"high\": 255, \"log\": False}], # set to constant for this example\n", + " \"min_data_in_leaf\": [\"int\", {\"low\": 20, \"high\": 20, \"log\": False}], # set to constant for this example\n", + " \"min_gain_to_split\": [\"float\", {\"low\": 1e-8, \"high\": 40, \"log\": False}],\n", + " \"min_sum_hessian_in_leaf\": [\"float\", {\"low\": 1e-8, \"high\": 500, \"log\": True}],\n", + " \"subsample\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n", + " \"feature_fraction\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n", + " \"boosting\": [\"categorical\", [\"gbdt\"]],\n", + "}\n", + "\n", + "np.random.seed(123)\n", + "opt_param = lgblss.hyper_opt(param_dict,\n", + " dtrain,\n", + " num_boost_round=100, # Number of boosting iterations.\n", + " nfold=5, # Number of cv-folds.\n", + " early_stopping_rounds=20, # Number of early-stopping rounds\n", + " max_minutes=10, # Time budget in minutes, i.e., stop study after the given number of minutes.\n", + " n_trials=30 , # The number of trials. 
If this argument is set to None, there is no limitation on the number of trials.\n", + " silence=True, # Controls the verbosity of the trail, i.e., user can silence the outputs of the trail.\n", + " seed=123, # Seed used to generate cv-folds.\n", + " hp_seed=123 # Seed for random number generator used in the Bayesian hyperparameter search.\n", + " )\n", + "\n", + "np.random.seed(123)\n", + "\n", + "opt_params = opt_param.copy()\n", + "n_rounds = opt_params[\"opt_rounds\"]\n", + "del opt_params[\"opt_rounds\"]\n", + "\n", + "# Train Model with optimized hyperparameters\n", + "lgblss.train(opt_params,\n", + " dtrain,\n", + " num_boost_round=n_rounds\n", + " )\n" + ], + "id": "f45c868160f1f08b", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/30 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FeatureContributionx_truex_noise1x_noise2x_noise3x_noise4x_noise5x_noise6x_noise7x_noise8x_noise9x_noise10Const
00.00.00.00.00.00.00.00.00.00.00.09.979578
10.00.00.00.00.00.00.00.00.00.00.09.979578
20.00.00.00.00.00.00.00.00.00.00.09.979578
30.00.00.00.00.00.00.00.00.00.00.09.979578
40.00.00.00.00.00.00.00.00.00.00.09.979578
.......................................
29950.00.00.00.00.00.00.00.00.00.00.09.979578
29960.00.00.00.00.00.00.00.00.00.00.09.979578
29970.00.00.00.00.00.00.00.00.00.00.09.979578
29980.00.00.00.00.00.00.00.00.00.00.09.979578
29990.00.00.00.00.00.00.00.00.00.00.09.979578
\n", + "

3000 rows × 12 columns

\n", + "" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 50 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Show contributions for each feature for scale parameter", + "id": "eaf2ad3ecc736152" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:09:16.706240Z", + "start_time": "2024-09-26T10:09:16.673598Z" + } + }, + "cell_type": "code", + "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\")\n", + "id": "c5453f7e5a378096", + "outputs": [ + { + "data": { + "text/plain": [ + "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 \\\n", + "0 0.410556 0.002107 0.0 0.0 0.000034 \n", + "1 0.411267 0.000683 0.0 0.0 -0.000340 \n", + "2 -0.597710 0.002107 0.0 0.0 -0.000340 \n", + "3 0.848812 0.002835 0.0 0.0 0.000034 \n", + "4 0.414533 0.001566 0.0 0.0 0.000867 \n", + "... ... ... ... ... ... \n", + "2995 0.411235 0.002835 0.0 0.0 -0.000340 \n", + "2996 0.380668 0.002107 0.0 0.0 -0.000340 \n", + "2997 -0.597620 0.001648 0.0 0.0 0.000034 \n", + "2998 -0.607374 -0.001427 0.0 0.0 -0.001144 \n", + "2999 0.410556 0.002107 0.0 0.0 0.000034 \n", + "\n", + "FeatureContribution x_noise5 x_noise6 x_noise7 x_noise8 x_noise9 \\\n", + "0 0.000197 0.004104 -0.000126 0.0 -0.000608 \n", + "1 0.000197 0.004816 -0.000126 0.0 -0.000608 \n", + "2 0.000197 0.004104 -0.000126 0.0 -0.000608 \n", + "3 0.000197 0.001399 -0.000126 0.0 -0.000608 \n", + "4 0.000123 0.002717 -0.004173 0.0 0.053938 \n", + "... ... ... ... ... ... 
\n", + "2995 0.000197 0.002135 -0.000126 0.0 -0.000608 \n", + "2996 0.000197 0.004402 -0.000126 0.0 -0.000608 \n", + "2997 0.000197 -0.004548 -0.000126 0.0 -0.000700 \n", + "2998 0.000888 0.002017 0.003200 0.0 -0.000029 \n", + "2999 0.000197 0.004104 -0.000126 0.0 -0.000608 \n", + "\n", + "FeatureContribution x_noise10 Const \n", + "0 -0.000503 0.653625 \n", + "1 -0.000130 0.653625 \n", + "2 -0.000130 0.653625 \n", + "3 0.001530 0.653625 \n", + "4 0.001895 0.653625 \n", + "... ... ... \n", + "2995 0.000432 0.653625 \n", + "2996 0.002378 0.653625 \n", + "2997 -0.000892 0.653625 \n", + "2998 -0.004399 0.653625 \n", + "2999 -0.000503 0.653625 \n", + "\n", + "[3000 rows x 12 columns]" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FeatureContributionx_truex_noise1x_noise2x_noise3x_noise4x_noise5x_noise6x_noise7x_noise8x_noise9x_noise10Const
00.4105560.0021070.00.00.0000340.0001970.004104-0.0001260.0-0.000608-0.0005030.653625
10.4112670.0006830.00.0-0.0003400.0001970.004816-0.0001260.0-0.000608-0.0001300.653625
2-0.5977100.0021070.00.0-0.0003400.0001970.004104-0.0001260.0-0.000608-0.0001300.653625
30.8488120.0028350.00.00.0000340.0001970.001399-0.0001260.0-0.0006080.0015300.653625
40.4145330.0015660.00.00.0008670.0001230.002717-0.0041730.00.0539380.0018950.653625
.......................................
29950.4112350.0028350.00.0-0.0003400.0001970.002135-0.0001260.0-0.0006080.0004320.653625
29960.3806680.0021070.00.0-0.0003400.0001970.004402-0.0001260.0-0.0006080.0023780.653625
2997-0.5976200.0016480.00.00.0000340.000197-0.004548-0.0001260.0-0.000700-0.0008920.653625
2998-0.607374-0.0014270.00.0-0.0011440.0008880.0020170.0032000.0-0.000029-0.0043990.653625
29990.4105560.0021070.00.00.0000340.0001970.004104-0.0001260.0-0.000608-0.0005030.653625
\n", + "

3000 rows × 12 columns

\n", + "
" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 51 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Show Mean Feature Impact for Data set", + "id": "394e64d247168fa0" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:09:16.722325Z", + "start_time": "2024-09-26T10:09:16.706781Z" + } + }, + "cell_type": "code", + "source": [ + "sum_of_contributions_column = \"SumOfContributions\"\n", + "mean_parameter_contribution = pred_param_contributions.abs().mean().unstack(\"distribution_arg\")\n", + "mean_parameter_contribution[sum_of_contributions_column] = mean_parameter_contribution.sum(1)\n", + "\n", + "mean_parameter_contribution.sort_values(sum_of_contributions_column, ascending=False).drop(columns=sum_of_contributions_column)\n" + ], + "id": "54d4970cf1957735", + "outputs": [ + { + "data": { + "text/plain": [ + "distribution_arg loc scale\n", + "FeatureContribution \n", + "Const 9.979577 0.653625\n", + "x_true 0.000000 0.591884\n", + "x_noise6 0.000000 0.004868\n", + "x_noise7 0.000000 0.004415\n", + "x_noise1 0.000000 0.003994\n", + "x_noise10 0.000000 0.002689\n", + "x_noise9 0.000000 0.002583\n", + "x_noise4 0.000000 0.001668\n", + "x_noise5 0.000000 0.000585\n", + "x_noise2 0.000000 0.000000\n", + "x_noise3 0.000000 0.000000\n", + "x_noise8 0.000000 0.000000" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
distribution_arglocscale
FeatureContribution
Const9.9795770.653625
x_true0.0000000.591884
x_noise60.0000000.004868
x_noise70.0000000.004415
x_noise10.0000000.003994
x_noise100.0000000.002689
x_noise90.0000000.002583
x_noise40.0000000.001668
x_noise50.0000000.000585
x_noise20.0000000.000000
x_noise30.0000000.000000
x_noise80.0000000.000000
\n", + "
" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 52 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Get correlation between contributions for the scale parameter ", + "id": "f7c73f303f04d4ff" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "", + "id": "8d5dc9e448d5c322" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:09:16.737953Z", + "start_time": "2024-09-26T10:09:16.722325Z" + } + }, + "cell_type": "code", + "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\").corr().dropna(how=\"all\").dropna(axis=1,how=\"all\")\n", + "id": "f331d8603042908", + "outputs": [ + { + "data": { + "text/plain": [ + "FeatureContribution x_true x_noise1 x_noise4 x_noise5 x_noise6 \\\n", + "FeatureContribution \n", + "x_true 1.000000 0.007743 -0.001231 -0.047812 -0.021563 \n", + "x_noise1 0.007743 1.000000 -0.006635 -0.022209 0.136772 \n", + "x_noise4 -0.001231 -0.006635 1.000000 -0.015972 -0.030669 \n", + "x_noise5 -0.047812 -0.022209 -0.015972 1.000000 0.006214 \n", + "x_noise6 -0.021563 0.136772 -0.030669 0.006214 1.000000 \n", + "x_noise7 0.015347 0.002129 0.474525 0.021844 0.029845 \n", + "x_noise9 0.024365 -0.006972 0.013082 0.015998 0.009551 \n", + "x_noise10 0.035477 0.012110 -0.035711 -0.001439 0.028450 \n", + "\n", + "FeatureContribution x_noise7 x_noise9 x_noise10 \n", + "FeatureContribution \n", + "x_true 0.015347 0.024365 0.035477 \n", + "x_noise1 0.002129 -0.006972 0.012110 \n", + "x_noise4 0.474525 0.013082 -0.035711 \n", + "x_noise5 0.021844 0.015998 -0.001439 \n", + "x_noise6 0.029845 0.009551 0.028450 \n", + "x_noise7 1.000000 0.023553 -0.015334 \n", + "x_noise9 0.023553 1.000000 -0.030410 \n", + "x_noise10 -0.015334 -0.030410 1.000000 " + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FeatureContributionx_truex_noise1x_noise4x_noise5x_noise6x_noise7x_noise9x_noise10
FeatureContribution
x_true1.0000000.007743-0.001231-0.047812-0.0215630.0153470.0243650.035477
x_noise10.0077431.000000-0.006635-0.0222090.1367720.002129-0.0069720.012110
x_noise4-0.001231-0.0066351.000000-0.015972-0.0306690.4745250.013082-0.035711
x_noise5-0.047812-0.022209-0.0159721.0000000.0062140.0218440.015998-0.001439
x_noise6-0.0215630.136772-0.0306690.0062141.0000000.0298450.0095510.028450
x_noise70.0153470.0021290.4745250.0218440.0298451.0000000.023553-0.015334
x_noise90.024365-0.0069720.0130820.0159980.0095510.0235531.000000-0.030410
x_noise100.0354770.012110-0.035711-0.0014390.028450-0.015334-0.0304101.000000
\n", + "
" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 53 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:09:16.753632Z", + "start_time": "2024-09-26T10:09:16.738083Z" + } + }, + "cell_type": "code", + "source": "", + "id": "ae0a0247ad688b42", + "outputs": [], + "execution_count": 53 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py index a8e90ee..62ca5ae 100644 --- a/lightgbmlss/distributions/distribution_utils.py +++ b/lightgbmlss/distributions/distribution_utils.py @@ -14,6 +14,7 @@ import warnings + class DistributionClass: """ Generic class that contains general functions for univariate distributions. @@ -43,6 +44,7 @@ class DistributionClass: penalize_crossing: bool Whether to include a penalty term to discourage crossing of expectiles. Only used for Expectile distribution. """ + def __init__(self, distribution: torch.distributions.Distribution = None, univariate: bool = True, @@ -375,51 +377,96 @@ def predict_dist(self, Predictions. """ + kwargs = dict() + if pred_type == "contributions": + kwargs["pred_contrib"] = True + n_outputs_per_dist = data.shape[1] + 1 + else: + n_outputs_per_dist = 1 + predt = torch.tensor( - booster.predict(data, raw_score=True), + booster.predict(data, raw_score=True, **kwargs), dtype=torch.float32 - ).reshape(-1, self.n_dist_param) + ).reshape(-1, self.n_dist_param * n_outputs_per_dist) # Set init_score as starting point for each distributional parameter. 
init_score_pred = torch.tensor( - np.ones(shape=(data.shape[0], 1))*start_values, + np.ones(shape=(data.shape[0], 1)) * start_values, dtype=torch.float32 ) - # The predictions don't include the init_score specified in creating the train data. - # Hence, it needs to be added manually with the corresponding transform for each distributional parameter. - dist_params_predt = np.concatenate( - [ - response_fun( - predt[:, i].reshape(-1, 1) + init_score_pred[:, i].reshape(-1, 1)).numpy() - for i, (dist_param, response_fun) in enumerate(self.param_dict.items()) - ], - axis=1, - ) - dist_params_predt = pd.DataFrame(dist_params_predt) - dist_params_predt.columns = self.param_dict.keys() - - # Draw samples from predicted response distribution - pred_samples_df = self.draw_samples(predt_params=dist_params_predt, - n_samples=n_samples, - seed=seed) + if pred_type == "contributions": + CONST_COL = "Const" + COLUMN_LEVELS = ["parameters", "feature_contributions"] + + feature_columns = data.columns.tolist() + [CONST_COL] + contributions_predt = pd.DataFrame( + predt, + columns=pd.MultiIndex.from_product( + [self.distribution_arg_names, feature_columns], + names=COLUMN_LEVELS + ), + index=data.index, + ) - if pred_type == "parameters": - return dist_params_predt + init_score_pred_df = pd.DataFrame( + init_score_pred, + columns=pd.MultiIndex.from_product( + [self.distribution_arg_names, ["Const"]], + names=COLUMN_LEVELS + ), + index=data.index + ) + contributions_predt[init_score_pred_df.columns] = ( + contributions_predt[init_score_pred_df.columns] + init_score_pred_df + ) + # The response function can't be applied to individual feature contributions, so they are returned untransformed + return contributions_predt + else: + # The predictions don't include the init_score specified in creating the train data. + # Hence, it needs to be added manually with the corresponding transform for each distributional parameter. 
+ dist_params_predt = np.concatenate( + [ + response_fun( + predt[:, i].reshape(-1, 1) + init_score_pred[:, i].reshape(-1, 1)).numpy() + for i, (dist_param, response_fun) in enumerate(self.param_dict.items()) + ], + axis=1, + ) + dist_params_predt = pd.DataFrame( + index=data.index, + data=dist_params_predt, + columns=pd.Index( + self.param_dict.keys(), + name=pred_type if pred_type == "expectiles" else "parameters" + ) + ) - elif pred_type == "expectiles": - return dist_params_predt + if pred_type == "parameters": + return dist_params_predt - elif pred_type == "samples": - return pred_samples_df + elif pred_type == "expectiles": + return dist_params_predt + else: - elif pred_type == "quantiles": - # Calculate quantiles from predicted response distribution - pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T - pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))] - if self.discrete: - pred_quant_df = pred_quant_df.astype(int) - return pred_quant_df + # Draw samples from predicted response distribution + pred_samples_df = self.draw_samples(predt_params=dist_params_predt, + n_samples=n_samples, + seed=seed) + pred_samples_df.columns.name = "samples" + if pred_type == "samples": + return pred_samples_df + + elif pred_type == "quantiles": + # Calculate quantiles from predicted response distribution + pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T + pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))] + if self.discrete: + pred_quant_df = pred_quant_df.astype(int) + pred_quant_df.columns.name = "quantiles" + return pred_quant_df + else: + raise RuntimeError(f"{pred_type=} not supported") def compute_gradients_and_hessians(self, loss: torch.tensor, @@ -635,7 +682,7 @@ def dist_select(self, try: loss, params = dist_sel.calculate_start_values(target=target.reshape(-1, 1), max_iter=max_iter) fit_df = pd.DataFrame.from_dict( - {self.loss_fn: loss.reshape(-1,), + 
{self.loss_fn: loss.reshape(-1, ), "distribution": str(dist_name), "params": [params] } diff --git a/lightgbmlss/model.py b/lightgbmlss/model.py index 33896ce..9a006bc 100644 --- a/lightgbmlss/model.py +++ b/lightgbmlss/model.py @@ -452,6 +452,7 @@ def predict(self, - "quantiles" calculates the quantiles from the predicted distribution. - "parameters" returns the predicted distributional parameters. - "expectiles" returns the predicted expectiles. + - "contributions" returns the contributions of each feature and a constant to each distributional parameter by calling booster.predict(pred_contrib=True) n_samples : int Number of samples to draw from the predicted distribution. quantiles : List[float] diff --git a/tests/test_model/test_model.py b/tests/test_model/test_model.py index c1f12a1..60cd6a6 100644 --- a/tests/test_model/test_model.py +++ b/tests/test_model/test_model.py @@ -1,3 +1,6 @@ +import numpy as np +import pandas as pd + from lightgbmlss.model import * from lightgbmlss.distributions.Gaussian import * from lightgbmlss.distributions.Mixture import * @@ -6,6 +9,7 @@ from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data import pytest from pytest import approx +from lightgbmlss.utils import identity_fn @pytest.fixture @@ -109,7 +113,7 @@ def test_model_univ_train_eval(self, univariate_data, univariate_lgblss, univari # Assertions assert isinstance(lgblss.booster, lgb.Booster) - def test_model_hpo(self, univariate_data, univariate_lgblss,): + def test_model_hpo(self, univariate_data, univariate_lgblss, ): # Unpack dtrain, _, _, _ = univariate_data lgblss = univariate_lgblss @@ -155,6 +159,7 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para pred_params = lgblss.predict(X_test, pred_type="parameters") pred_samples = lgblss.predict(X_test, pred_type="samples", n_samples=n_samples) pred_quantiles = lgblss.predict(X_test, pred_type="quantiles", quantiles=quantiles) + pred_contributions = 
lgblss.predict(X_test, pred_type="contributions") # 
Assertions assert isinstance(pred_params, (pd.DataFrame, type(None))) @@ -162,16 +167,41 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para assert not np.isinf(pred_params).any().any() assert pred_params.shape[1] == lgblss.dist.n_dist_param assert approx(pred_params["loc"].mean(), abs=0.2) == 10.0 + assert pred_params.columns.name == "parameters" assert isinstance(pred_samples, (pd.DataFrame, type(None))) assert not pred_samples.isna().any().any() assert not np.isinf(pred_samples).any().any() assert pred_samples.shape[1] == n_samples + assert pred_samples.columns.name == "samples" assert isinstance(pred_quantiles, (pd.DataFrame, type(None))) assert not pred_quantiles.isna().any().any() assert not np.isinf(pred_quantiles).any().any() assert pred_quantiles.shape[1] == len(quantiles) + assert pred_quantiles.columns.name == "quantiles" + + assert isinstance(pred_contributions, (pd.DataFrame, type(None))) + assert not pred_contributions.isna().any().any() + assert not np.isinf(pred_contributions).any().any() + assert (pred_contributions.shape[1] == + lgblss.dist.n_dist_param * (X_test.shape[1] + 1) + ) + + assert pred_contributions.columns.names == ["parameters", "feature_contributions"] + + for key, response_func in lgblss.dist.param_dict.items(): + # Sum contributions for each parameter and apply response function + pred_contributions_combined = ( + pd.Series(response_func( + torch.tensor( + pred_contributions.xs(key, level="parameters", axis=1).sum(axis=1).values) + ))) + assert np.allclose( + pred_contributions_combined, + pred_params[key], atol=1e-5 + ) + def test_model_plot(self, univariate_data, univariate_lgblss, univariate_params): # Unpack