diff --git a/README.md b/README.md index dc87ab6..e0d584a 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ xaiflow integrates seamlessly with MLflow to generate interactive HTML reports for SHAP analysis. Instead of static charts and images, you get rich, interactive visualizations that stakeholders can explore and understand. +Feature showcase video: +[![xaiflow showcase](video/video_thumbnail.png)](https://github.com/user-attachments/assets/f508fa6f-ab0f-493d-a892-ed958331e30a) +*Click the image above to watch the feature showcase video.* + ## What We're Trying to Achieve Most ML workflows produce explanations as static images or basic charts, which creates several problems: @@ -38,10 +42,9 @@ with mlflow.start_run(): # Add interactive explainable AI reports plugin = XaiflowPlugin() - plugin.log_feature_importance_report( + plugin.log_xai_report( feature_names=X.columns.tolist(), shap_values=shap_values, - report_name="model_explanation.html" ) ``` @@ -81,7 +84,7 @@ xaiflow/ ### Core Components **MLflow Integration** (`mlflow_plugin.py`) -The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_feature_importance_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts. +The `XaiflowPlugin` class handles the integration with MLflow. The main method `log_xai_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts. **Report Generation** (`report_generator.py`) The `ReportGenerator` class converts SHAP data into interactive HTML reports using Jinja2 templating. It handles template loading, asset bundling, and data injection into the frontend components. 
@@ -127,11 +130,10 @@ feature_encodings = { 'region': {0: 'North', 1: 'South', 2: 'East', 3: 'West'} } -plugin.log_feature_importance_report( +plugin.log_xai_report( feature_names=feature_names, shap_values=shap_values, feature_encodings=feature_encodings, - report_name="enhanced_report.html" ) ``` diff --git a/examples/notebooks/auto_mpg_example.ipynb b/examples/notebooks/auto_mpg_example.ipynb index 8d05e6e..b9521a9 100644 --- a/examples/notebooks/auto_mpg_example.ipynb +++ b/examples/notebooks/auto_mpg_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "79de156f", "metadata": {}, "outputs": [ @@ -10,7 +10,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/tobias/programming/cloudexplain/ce-mlflow-extension/ce-mlflow-extension/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "c:\\programming\\cloudexplain\\xflow\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "0da5d7e2", "metadata": {}, "outputs": [ @@ -50,13 +50,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loaded bundle.js content (218107 characters)\n", - "Saved report data to test_report_data.json\n", - "logged to test_report.html\n", - "Feature importance report logged to MLflow: reports/test_report_auto_mpg.html\n", - "Run ID: 72ea715bfe9c42bc840388933f6999a8. 
If you are running mlflow locally use:\n", + "Loaded bundle.js content (225719 characters)\n", + "Feature importance report logged to MLflow: reports/feature_importance_report.html\n", + "Run ID: 7521c3f260f84a5d8e038a13bc91498b. If you are running mlflow locally use:\n", "python -m mlflow ui --port 5000\n", - "Then open http://localhost:5000/#/experiments/921177506761828334/runs/72ea715bfe9c42bc840388933f6999a8 to view the report.\n" + "Then open http://localhost:5000/#/experiments/557047036753041520/runs/7521c3f260f84a5d8e038a13bc91498b to view the report. Note: it's important to start mlflow in the directory in which you execute the notebook.\n" ] } ], @@ -91,10 +89,9 @@ " feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},\n", " 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},\n", " 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}\n", - " artifact_path = plugin.log_feature_importance_report(\n", + " artifact_path = plugin.log_xai_report(\n", " feature_names=list(X.columns),\n", " shap_values=shap_values,\n", - " report_name=\"test_report_auto_mpg.html\",\n", " feature_encodings=feature_encodings\n", " )\n", " run_id = mlflow.active_run().info.run_id\n", @@ -119,7 +116,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/examples/notebooks/azure_ml_auto_mpg_example.ipynb b/examples/notebooks/azure_ml_auto_mpg_example.ipynb index b2d52c5..37d79e4 100644 --- a/examples/notebooks/azure_ml_auto_mpg_example.ipynb +++ b/examples/notebooks/azure_ml_auto_mpg_example.ipynb @@ -258,10 +258,9 @@ " feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},\n", " 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 
'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},\n", " 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}\n", - " artifact_path = plugin.log_feature_importance_report(\n", + " artifact_path = plugin.log_xai_report(\n", " feature_names=list(X.columns),\n", " shap_values=shap_values,\n", - " report_name=\"test_report_auto_mpg.html\",\n", " feature_encodings=feature_encodings\n", " )\n", " run_id = mlflow.active_run().info.run_id\n", diff --git a/examples/scripts/auto_mpg_example.py b/examples/scripts/auto_mpg_example.py index 969465f..14b29b2 100644 --- a/examples/scripts/auto_mpg_example.py +++ b/examples/scripts/auto_mpg_example.py @@ -43,11 +43,12 @@ feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'}, 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'}, 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}} - artifact_path = plugin.log_feature_importance_report( + artifact_path = plugin.log_xai_report( feature_names=list(X.columns), shap_values=shap_values, - report_name="test_report_auto_mpg.html", - feature_encodings=feature_encodings + feature_encodings=feature_encodings, + # assign each sample to a custom group label + group_labels=["Custom Group " + str(i % 4) for i in range(len(X))], ) run_id = mlflow.active_run().info.run_id print(f"Run ID: {run_id}. 
If you are running mlflow locally use:\npython -m mlflow ui --port 5000\nThen open http://localhost:5000/#/experiments/{mlflow.get_experiment_by_name(experiment_name).experiment_id}/runs/{run_id} to view the report.", diff --git a/src/xaiflow/mlflow_plugin.py b/src/xaiflow/mlflow_plugin.py index f14a84a..4d3b695 100644 --- a/src/xaiflow/mlflow_plugin.py +++ b/src/xaiflow/mlflow_plugin.py @@ -26,7 +26,7 @@ def __init__(self): self.template_dir = os.path.join(os.path.dirname(__file__), 'templates') self.env = Environment(loader=FileSystemLoader(self.template_dir)) - def log_feature_importance_report( + def log_xai_report( self, feature_names: List[str], shap_values: Explanation, diff --git a/src/xflow.egg-info/PKG-INFO b/src/xflow.egg-info/PKG-INFO index 48a3f07..8826561 100644 --- a/src/xflow.egg-info/PKG-INFO +++ b/src/xflow.egg-info/PKG-INFO @@ -80,7 +80,7 @@ with mlflow.start_run(): # Add interactive explainable AI reports plugin = CEMLflowPlugin() - plugin.log_feature_importance_report( + plugin.log_xai_report( feature_names=X.columns.tolist(), shap_values=shap_values, report_name="model_explanation.html" @@ -121,7 +121,7 @@ xaiflow/ ### Core Components **MLflow Integration** (`mlflow_plugin.py`) -The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_feature_importance_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts. +The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_xai_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts. **Report Generation** (`report_generator.py`) The `ReportGenerator` class converts SHAP data into interactive HTML reports using Jinja2 templating. It handles template loading, asset bundling, and data injection into the frontend components. 
@@ -167,7 +167,7 @@ feature_encodings = { 'region': {0: 'North', 1: 'South', 2: 'East', 3: 'West'} } -plugin.log_feature_importance_report( +plugin.log_xai_report( feature_names=feature_names, shap_values=shap_values, feature_encodings=feature_encodings, diff --git a/tests/test_mlflow_plugin.py b/tests/test_mlflow_plugin.py index 9ecd353..ea36b60 100644 --- a/tests/test_mlflow_plugin.py +++ b/tests/test_mlflow_plugin.py @@ -4,8 +4,6 @@ import numpy as np from sklearn.datasets import fetch_openml from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier -from xgboost import XGBClassifier -from catboost import CatBoostClassifier from sklearn.preprocessing import LabelEncoder import shap from typing import Callable @@ -187,7 +185,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): mocker.patch("mlflow.log_artifact") with mlflow.start_run(run_name="auto_mpg_test"): - plugin.log_feature_importance_report( + plugin.log_xai_report( shap_values=shap_values, feature_encodings=feature_encodings, feature_names=list(X.columns), @@ -263,7 +261,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): mocker.patch("mlflow.log_artifact") with mlflow.start_run(run_name="auto_mpg_test"): - plugin.log_feature_importance_report( + plugin.log_xai_report( shap_values=shap_values, feature_encodings=feature_encodings, feature_names=list(X.columns), @@ -319,11 +317,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): mocker.patch("mlflow.log_artifact") with mlflow.start_run(run_name="auto_mpg_test"): - plugin.log_feature_importance_report( + plugin.log_xai_report( shap_values=shap_values, feature_encodings=feature_encodings, feature_names=list(X.columns), group_labels=["Group 1", "Group 2", "Group 3", "Group 4"] * int(len(shap_values) / 4) # Example group labels ) - html_content_click_test(Path(output_path)) - # return html_content \ No newline at end of file + html_content_click_test(Path(output_path)) \ No newline at end of file diff --git a/video/video_thumbnail.png 
b/video/video_thumbnail.png new file mode 100644 index 0000000..cee81fb Binary files /dev/null and b/video/video_thumbnail.png differ