Add a visualisation utility to plot the predicted distribution (regression)#987
Add a visualisation utility to plot the predicted distribution (regression)#987eliott-kalfon wants to merge 4 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces visualization utilities for TabPFN, specifically adding the plot_regression_distribution function to plot predicted target distributions for single samples, along with an example script and a new viz optional dependency. The review feedback is highly constructive and should be addressed to improve robustness and compatibility. Key recommendations include ensuring backward compatibility with Matplotlib versions prior to 3.7.0, adding input validation for the plotting parameters, capping the smoothing filter size to prevent out-of-bounds issues, and handling cases where the visible density is zero to avoid y-axis rendering warnings.
| # Local import because matplotlib is an optional dependency. | ||
| try: | ||
| import matplotlib.pyplot as plt # noqa: PLC0415 | ||
| from matplotlib.patches import Patch # noqa: PLC0415 | ||
| except ModuleNotFoundError as err: | ||
| raise ModuleNotFoundError( | ||
| "matplotlib is required for plotting. " | ||
| 'Install it with `pip install "tabpfn[viz]"`' | ||
| ) from err |
There was a problem hiding this comment.
To prevent unexpected runtime errors and ensure robust defensive programming, we should validate the input arguments (quantile_interval, zoom_quantile, smooth, and statistics) at the beginning of the function before performing any computations or plotting.
if quantile_interval is not None:
if len(quantile_interval) != 2:
raise ValueError("quantile_interval must be a tuple of length 2.")
if not all(0 <= q <= 1 for q in quantile_interval):
raise ValueError("quantile_interval values must be between 0 and 1.")
if quantile_interval[0] > quantile_interval[1]:
raise ValueError("quantile_interval must be in ascending order (low, high).")
if zoom_quantile is not None and not (0 < zoom_quantile <= 1):
raise ValueError("zoom_quantile must be in the range (0, 1].")
if smooth < 0:
raise ValueError("smooth must be non-negative.")
for name in statistics:
if name not in _STAT_STYLES:
raise ValueError(
f"Unknown statistic '{name}'. Supported statistics are: {list(_STAT_STYLES.keys())}"
)
# Local import because matplotlib is an optional dependency.
try:
import matplotlib.pyplot as plt # noqa: PLC0415
from matplotlib.patches import Patch # noqa: PLC0415
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"matplotlib is required for plotting. "
'Install it with `pip install "tabpfn[viz]"`'
) from errThere was a problem hiding this comment.
Not sure about this, as it may bloat the codebase. I am leaving the comment open, to be considered by the reviewer. Thank you!
There was a problem hiding this comment.
I think these checks would be fair to add, but not essential. Also serve as documentation.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
There was a problem hiding this comment.
Thanks @eliott-kalfon , looks great!
Just a suggestion for the API and a few nits.
Also, let's move to british spelling -> "visualisation" everywhere.
| # Local import because matplotlib is an optional dependency. | ||
| try: | ||
| import matplotlib.pyplot as plt # noqa: PLC0415 | ||
| from matplotlib.patches import Patch # noqa: PLC0415 | ||
| except ModuleNotFoundError as err: | ||
| raise ModuleNotFoundError( | ||
| "matplotlib is required for plotting. " | ||
| 'Install it with `pip install "tabpfn[viz]"`' | ||
| ) from err |
There was a problem hiding this comment.
I think these checks would be fair to add, but not essential. Also serve as documentation.
|
|
||
| def plot_regression_distribution( | ||
| regressor: TabPFNRegressor, | ||
| x: object, |
There was a problem hiding this comment.
Not sure about the interface of this one. Let's say the user does not only want to visualize it but also want's to do something else with the output of the prediction.
What about not passing the full regressor and x and predict inside this plotting function, and rather just passing the output of regressor.predict(..., output_type="full")?
Might need a small input validation as well.
| for ax, idx, label in zip( | ||
| axes, | ||
| [low_idx, mid_idx, high_idx], | ||
| ["low prediction", "median prediction", "high prediction"], |
There was a problem hiding this comment.
could you explain in the plot what low, median, and high predictions mean?
| [low_idx, mid_idx, high_idx], | ||
| ["low prediction", "median prediction", "high prediction"], | ||
| ): | ||
| plot_regression_distribution(reg, X_test[idx], ax=ax) |
There was a problem hiding this comment.
if you change the interface (see other comment) you can also predict three points in a batch and then just pass the results here, so it's a bit more flexible.
|
|
||
| for name in statistics: | ||
| value = float(np.atleast_1d(out[name])[0]) | ||
| c, ls = _STAT_STYLES[name] |
There was a problem hiding this comment.
would throw a cryptic error if name is not in STAT_STYLES. Maybe raise a nicer error or add types for the modes.
Issue
The goal of this PR is to enable users to visualise the predicted distribution for a given observation in a regression setting.
To the review, some design choices:
TabPFNRegressorAPItabpfn.visualizationmodulevizsubprojectThe objective is to give the user this functionality without adding clutter or dependencies.
Public API Changes
How Has This Been Tested?
Local run and tests
Checklist
changelog/README.md), or "no changelog needed" label requested.