@@ -11,14 +11,22 @@ def plot_mo(
1111 y_rf : np .ndarray = None ,
1212 pareto_front : bool = False ,
1313 y_best : np .ndarray = None ,
14+ y_add : np .ndarray = None ,
15+ y_add2 : np .ndarray = None ,
16+ y_add_color = "blue" ,
17+ y_add2_color = "green" ,
1418 title : str = "" ,
1519 y_orig : np .ndarray = None ,
1620 pareto_front_orig : bool = False ,
1721 pareto_label : bool = False ,
1822 y_rf_color = "blue" ,
1923 y_best_color = "red" ,
20- x_axis_transformation : str = "id" , # New argument for x-axis transformation
21- y_axis_transformation : str = "id" , # New argument for y-axis transformation
24+ x_axis_transformation : str = "id" ,
25+ y_axis_transformation : str = "id" ,
26+ y_best_label = "Best" ,
27+ y_add_label = "Add" ,
28+ y_add2_label = "Add2" ,
29+ filename : str = None ,
2230) -> None :
2331 """
2432 Generates scatter plots for each combination of two targets from a multi-output prediction while highlighting Pareto optimal points.
@@ -30,6 +38,13 @@ def plot_mo(
3038 pareto (str): Specifies whether to compute Pareto front based on 'min' or 'max' criterion.
3139 pareto_front (bool): If True, connect Pareto optimal points with a red line for y_rf.
3240 y_best (np.ndarray, optional): A NumPy array representing the best point to highlight in red. Defaults to None.
41+ y_add (np.ndarray, optional): A NumPy array representing the additional points to highlight in blue. Defaults to None.
42+ y_add2 (np.ndarray, optional): A NumPy array representing the additional points to highlight in green. Defaults to None.
43+ y_add_color (str): The color of the additional points. Defaults to "blue".
44+ y_add2_color (str): The color of the additional points. Defaults to "green".
45+ y_best_label (str): The label for the best point. Defaults to "Best".
46+ y_add_label (str): The label for the additional points. Defaults to "Add".
47+ y_add2_label (str): The label for the additional points. Defaults to "Add2".
3348 title (str): The title of the plot. Defaults to "" (empty string).
3449 y_orig (np.ndarray, optional): The original target values with shape (n_samples, n_targets). Defaults to None.
3550 pareto_front_orig (bool): If True, connect Pareto optimal points with a light blue line for y_orig. Defaults to False.
@@ -38,9 +53,11 @@ def plot_mo(
3853 y_best_color (str): The color of the best point. Defaults to "red".
3954 x_axis_transformation (str): Transformation for the x-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
4055 y_axis_transformation (str): Transformation for the y-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
56+ filename (str, optional):
57+ If provided, saves the plot to the specified file. Supports "pdf" and "png" formats. Defaults to None.
4158
4259 Returns:
43- None: Displays the plot.
60+ None: Displays or saves the plot.
4461
4562 Examples:
4663 >>> from spotpython.mo.plot import plot_mo
@@ -50,7 +67,7 @@ def plot_mo(
5067 >>> pareto = "min"
5168 >>> y_rf = np.random.rand(100, 2)
5269 >>> y_orig = np.random.rand(100, 2)
53- >>> plot_mo(target_names, combinations, pareto, y_rf=y_rf, y_orig=y_orig)
70+ >>> plot_mo(target_names, combinations, pareto, y_rf=y_rf, y_orig=y_orig, filename="plot.png" )
5471 """
5572 # Convert y_rf to numpy array if it's a pandas DataFrame
5673 if isinstance (y_rf , pd .DataFrame ):
@@ -72,14 +89,14 @@ def plot_mo(
7289 if y_orig is not None :
7390 minimize = pareto == "min"
7491 pareto_mask_orig = is_pareto_efficient (y_orig [:, [i , j ]], minimize )
75- plt .scatter (y_orig [:, i ], y_orig [:, j ], edgecolor = "w" , c = "gray" , s = s , marker = "o" , alpha = a , label = "Original Points" )
76- plt .scatter (y_orig [pareto_mask_orig , i ], y_orig [pareto_mask_orig , j ], edgecolor = "k" , c = "gray" , s = pareto_size , marker = "o" , alpha = a , label = "Original Pareto" )
92+ plt .scatter (y_orig [:, i ], y_orig [:, j ], edgecolor = "w" , c = "gray" , s = s , marker = "o" , alpha = a , label = "Non-Pareto Points" )
93+ plt .scatter (y_orig [pareto_mask_orig , i ], y_orig [pareto_mask_orig , j ], edgecolor = "k" , c = "gray" , s = pareto_size , marker = "o" , alpha = a , label = "Pareto Points " )
7794 if pareto_label :
7895 for idx in np .where (pareto_mask_orig )[0 ]:
7996 plt .text (y_orig [idx , i ], y_orig [idx , j ], str (idx ), color = "black" , fontsize = 8 , ha = "center" , va = "center" )
8097 if pareto_front_orig :
8198 sorted_indices_orig = np .argsort (y_orig [pareto_mask_orig , i ])
82- plt .plot (y_orig [pareto_mask_orig , i ][sorted_indices_orig ], y_orig [pareto_mask_orig , j ][sorted_indices_orig ], "k-" , alpha = a , label = "Original Pareto Front" )
99+ plt .plot (y_orig [pareto_mask_orig , i ][sorted_indices_orig ], y_orig [pareto_mask_orig , j ][sorted_indices_orig ], "k-" , alpha = a , label = "Pareto Front" )
83100
84101 if y_rf is not None :
85102 minimize = pareto == "min"
@@ -101,7 +118,11 @@ def plot_mo(
101118 )
102119
103120 if y_best is not None :
104- plt .scatter (y_best [:, i ], y_best [:, j ], edgecolor = "k" , c = y_best_color , s = s , marker = "D" , alpha = 1 , label = "Best" )
121+ plt .scatter (y_best [:, i ], y_best [:, j ], edgecolor = "k" , c = y_best_color , s = s , marker = "D" , alpha = 1 , label = y_best_label )
122+ if y_add is not None :
123+ plt .scatter (y_add [:, i ], y_add [:, j ], edgecolor = "k" , c = y_add_color , s = s , marker = "D" , alpha = 1 , label = y_add_label )
124+ if y_add2 is not None :
125+ plt .scatter (y_add2 [:, i ], y_add2 [:, j ], edgecolor = "k" , c = y_add2_color , s = s , marker = "D" , alpha = 1 , label = y_add2_label )
105126
106127 # Apply axis transformations
107128 if x_axis_transformation == "log" :
@@ -117,4 +138,11 @@ def plot_mo(
117138 plt .grid ()
118139 plt .title (title )
119140 plt .legend ()
120- plt .show ()
141+ # Save or show the plot
142+ if filename :
143+ if filename .endswith (".pdf" ) or filename .endswith (".png" ):
144+ plt .savefig (filename , format = filename .split ("." )[- 1 ])
145+ else :
146+ raise ValueError ("Filename must have a valid suffix: '.pdf' or '.png'." )
147+ else :
148+ plt .show ()
0 commit comments