Skip to content

Commit a19aa96

Browse files
plot mo
1 parent 5161f43 commit a19aa96

2 files changed

Lines changed: 38 additions & 10 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.28.12"
10+
version = "0.28.13"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/mo/plot.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@ def plot_mo(
1111
y_rf: np.ndarray = None,
1212
pareto_front: bool = False,
1313
y_best: np.ndarray = None,
14+
y_add: np.ndarray = None,
15+
y_add2: np.ndarray = None,
16+
y_add_color="blue",
17+
y_add2_color="green",
1418
title: str = "",
1519
y_orig: np.ndarray = None,
1620
pareto_front_orig: bool = False,
1721
pareto_label: bool = False,
1822
y_rf_color="blue",
1923
y_best_color="red",
20-
x_axis_transformation: str = "id", # New argument for x-axis transformation
21-
y_axis_transformation: str = "id", # New argument for y-axis transformation
24+
x_axis_transformation: str = "id",
25+
y_axis_transformation: str = "id",
26+
y_best_label="Best",
27+
y_add_label="Add",
28+
y_add2_label="Add2",
29+
filename: str = None,
2230
) -> None:
2331
"""
2432
Generates scatter plots for each combination of two targets from a multi-output prediction while highlighting Pareto optimal points.
@@ -30,6 +38,13 @@ def plot_mo(
3038
pareto (str): Specifies whether to compute Pareto front based on 'min' or 'max' criterion.
3139
pareto_front (bool): If True, connect Pareto optimal points with a red line for y_rf.
3240
y_best (np.ndarray, optional): A NumPy array representing the best point to highlight in red. Defaults to None.
41+
y_add (np.ndarray, optional): A NumPy array representing the additional points to highlight in blue. Defaults to None.
42+
y_add2 (np.ndarray, optional): A NumPy array representing the additional points to highlight in green. Defaults to None.
43+
y_add_color (str): The color of the additional points. Defaults to "blue".
44+
y_add2_color (str): The color of the additional points. Defaults to "green".
45+
y_best_label (str): The label for the best point. Defaults to "Best".
46+
y_add_label (str): The label for the additional points. Defaults to "Add".
47+
y_add2_label (str): The label for the additional points. Defaults to "Add2".
3348
title (str): The title of the plot. Defaults to "" (empty string).
3449
y_orig (np.ndarray, optional): The original target values with shape (n_samples, n_targets). Defaults to None.
3550
pareto_front_orig (bool): If True, connect Pareto optimal points with a light blue line for y_orig. Defaults to False.
@@ -38,9 +53,11 @@ def plot_mo(
3853
y_best_color (str): The color of the best point. Defaults to "red".
3954
x_axis_transformation (str): Transformation for the x-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
4055
y_axis_transformation (str): Transformation for the y-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
56+
filename (str, optional):
57+
If provided, saves the plot to the specified file. Supports "pdf" and "png" formats. Defaults to None.
4158
4259
Returns:
43-
None: Displays the plot.
60+
None: Displays or saves the plot.
4461
4562
Examples:
4663
>>> from spotpython.mo.plot import plot_mo
@@ -50,7 +67,7 @@ def plot_mo(
5067
>>> pareto = "min"
5168
>>> y_rf = np.random.rand(100, 2)
5269
>>> y_orig = np.random.rand(100, 2)
53-
>>> plot_mo(target_names, combinations, pareto, y_rf=y_rf, y_orig=y_orig)
70+
>>> plot_mo(target_names, combinations, pareto, y_rf=y_rf, y_orig=y_orig, filename="plot.png")
5471
"""
5572
# Convert y_rf to numpy array if it's a pandas DataFrame
5673
if isinstance(y_rf, pd.DataFrame):
@@ -72,14 +89,14 @@ def plot_mo(
7289
if y_orig is not None:
7390
minimize = pareto == "min"
7491
pareto_mask_orig = is_pareto_efficient(y_orig[:, [i, j]], minimize)
75-
plt.scatter(y_orig[:, i], y_orig[:, j], edgecolor="w", c="gray", s=s, marker="o", alpha=a, label="Original Points")
76-
plt.scatter(y_orig[pareto_mask_orig, i], y_orig[pareto_mask_orig, j], edgecolor="k", c="gray", s=pareto_size, marker="o", alpha=a, label="Original Pareto")
92+
plt.scatter(y_orig[:, i], y_orig[:, j], edgecolor="w", c="gray", s=s, marker="o", alpha=a, label="Non-Pareto Points")
93+
plt.scatter(y_orig[pareto_mask_orig, i], y_orig[pareto_mask_orig, j], edgecolor="k", c="gray", s=pareto_size, marker="o", alpha=a, label="Pareto Points")
7794
if pareto_label:
7895
for idx in np.where(pareto_mask_orig)[0]:
7996
plt.text(y_orig[idx, i], y_orig[idx, j], str(idx), color="black", fontsize=8, ha="center", va="center")
8097
if pareto_front_orig:
8198
sorted_indices_orig = np.argsort(y_orig[pareto_mask_orig, i])
82-
plt.plot(y_orig[pareto_mask_orig, i][sorted_indices_orig], y_orig[pareto_mask_orig, j][sorted_indices_orig], "k-", alpha=a, label="Original Pareto Front")
99+
plt.plot(y_orig[pareto_mask_orig, i][sorted_indices_orig], y_orig[pareto_mask_orig, j][sorted_indices_orig], "k-", alpha=a, label="Pareto Front")
83100

84101
if y_rf is not None:
85102
minimize = pareto == "min"
@@ -101,7 +118,11 @@ def plot_mo(
101118
)
102119

103120
if y_best is not None:
104-
plt.scatter(y_best[:, i], y_best[:, j], edgecolor="k", c=y_best_color, s=s, marker="D", alpha=1, label="Best")
121+
plt.scatter(y_best[:, i], y_best[:, j], edgecolor="k", c=y_best_color, s=s, marker="D", alpha=1, label=y_best_label)
122+
if y_add is not None:
123+
plt.scatter(y_add[:, i], y_add[:, j], edgecolor="k", c=y_add_color, s=s, marker="D", alpha=1, label=y_add_label)
124+
if y_add2 is not None:
125+
plt.scatter(y_add2[:, i], y_add2[:, j], edgecolor="k", c=y_add2_color, s=s, marker="D", alpha=1, label=y_add2_label)
105126

106127
# Apply axis transformations
107128
if x_axis_transformation == "log":
@@ -117,4 +138,11 @@ def plot_mo(
117138
plt.grid()
118139
plt.title(title)
119140
plt.legend()
120-
plt.show()
141+
# Save or show the plot
142+
if filename:
143+
if filename.endswith(".pdf") or filename.endswith(".png"):
144+
plt.savefig(filename, format=filename.split(".")[-1])
145+
else:
146+
raise ValueError("Filename must have a valid suffix: '.pdf' or '.png'.")
147+
else:
148+
plt.show()

0 commit comments

Comments
 (0)