Skip to content

Add a visualisation utility to plot the predicted distribution (regression)#987

Open
eliott-kalfon wants to merge 4 commits into
mainfrom
eliott/bar-distribution-visual
Open

Add a visualisation utility to plot the predicted distribution (regression)#987
eliott-kalfon wants to merge 4 commits into
mainfrom
eliott/bar-distribution-visual

Conversation

@eliott-kalfon
Copy link
Copy Markdown
Contributor

Issue

The goal of this PR is to enable users to visualise the predicted distribution for a given observation in a regression setting.

To the review, some design choices:

  • No change to the TabPFNRegressor API
  • Creation of a tabpfn.visualization module
  • matplotlib was not added to core dependencies but to the viz subproject

The objective is to give the user this functionality without adding clutter or dependencies.


Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

How Has This Been Tested?

Local run and tests

Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • A changelog entry has been added (see changelog/README.md), or "no changelog needed" label requested.
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

@eliott-kalfon eliott-kalfon requested a review from bejaeger May 29, 2026 14:22
@eliott-kalfon eliott-kalfon requested a review from a team as a code owner May 29, 2026 14:22
@eliott-kalfon eliott-kalfon changed the title Eliott/bar distribution visual Add a visualisation utility to plot the predicted distribution (regression) May 29, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces visualization utilities for TabPFN, specifically adding the plot_regression_distribution function to plot predicted target distributions for single samples, along with an example script and a new viz optional dependency. The review feedback is highly constructive and should be addressed to improve robustness and compatibility. Key recommendations include ensuring backward compatibility with Matplotlib versions prior to 3.7.0, adding input validation for the plotting parameters, capping the smoothing filter size to prevent out-of-bounds issues, and handling cases where the visible density is zero to avoid y-axis rendering warnings.

Comment thread examples/plot_regression_distribution.py Outdated
Comment on lines +72 to +80
# Local import because matplotlib is an optional dependency.
try:
import matplotlib.pyplot as plt # noqa: PLC0415
from matplotlib.patches import Patch # noqa: PLC0415
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"matplotlib is required for plotting. "
'Install it with `pip install "tabpfn[viz]"`'
) from err
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent unexpected runtime errors and ensure robust defensive programming, we should validate the input arguments (quantile_interval, zoom_quantile, smooth, and statistics) at the beginning of the function before performing any computations or plotting.

    if quantile_interval is not None:
        if len(quantile_interval) != 2:
            raise ValueError("quantile_interval must be a tuple of length 2.")
        if not all(0 <= q <= 1 for q in quantile_interval):
            raise ValueError("quantile_interval values must be between 0 and 1.")
        if quantile_interval[0] > quantile_interval[1]:
            raise ValueError("quantile_interval must be in ascending order (low, high).")

    if zoom_quantile is not None and not (0 < zoom_quantile <= 1):
        raise ValueError("zoom_quantile must be in the range (0, 1].")

    if smooth < 0:
        raise ValueError("smooth must be non-negative.")

    for name in statistics:
        if name not in _STAT_STYLES:
            raise ValueError(
                f"Unknown statistic '{name}'. Supported statistics are: {list(_STAT_STYLES.keys())}"
            )

    # Local import because matplotlib is an optional dependency.
    try:
        import matplotlib.pyplot as plt  # noqa: PLC0415
        from matplotlib.patches import Patch  # noqa: PLC0415
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "matplotlib is required for plotting. "
            'Install it with `pip install "tabpfn[viz]"`'
        ) from err

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this, as it may bloat the codebase. I am leaving the comment open, to be considered by the reviewer. Thank you!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these checks would be fair to add, but not essential. Also serve as documentation.

Comment thread src/tabpfn/visualization/regression_distribution.py
Comment thread src/tabpfn/visualization/regression_distribution.py
eliott-kalfon and others added 2 commits May 29, 2026 16:25
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Copy link
Copy Markdown
Collaborator

@bejaeger bejaeger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @eliott-kalfon , looks great!
Just a suggestion for the API and a few nits.
Also, let's move to british spelling -> "visualisation" everywhere.

Comment on lines +72 to +80
# Local import because matplotlib is an optional dependency.
try:
import matplotlib.pyplot as plt # noqa: PLC0415
from matplotlib.patches import Patch # noqa: PLC0415
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"matplotlib is required for plotting. "
'Install it with `pip install "tabpfn[viz]"`'
) from err
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these checks would be fair to add, but not essential. Also serve as documentation.


def plot_regression_distribution(
regressor: TabPFNRegressor,
x: object,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the interface of this one. Let's say the user does not only want to visualize it but also want's to do something else with the output of the prediction.
What about not passing the full regressor and x and predict inside this plotting function, and rather just passing the output of regressor.predict(..., output_type="full")?
Might need a small input validation as well.

for ax, idx, label in zip(
axes,
[low_idx, mid_idx, high_idx],
["low prediction", "median prediction", "high prediction"],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you explain in the plot what low, median, and high predictions mean?

[low_idx, mid_idx, high_idx],
["low prediction", "median prediction", "high prediction"],
):
plot_regression_distribution(reg, X_test[idx], ax=ax)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you change the interface (see other comment) you can also predict three points in a batch and then just pass the results here, so it's a bit more flexible.


for name in statistics:
value = float(np.atleast_1d(out[name])[0])
c, ls = _STAT_STYLES[name]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would throw a cryptic error if name is not in STAT_STYLES. Maybe raise a nicer error or add types for the modes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants