Skip to content

Commit b0fe5d2

Browse files
0.16.4
1 parent 0266d64 commit b0fe5d2

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.16.3"
10+
version = "0.16.4"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,26 +444,27 @@ def plot_nn_values_scatter(
444444

445445
if absolute:
446446
reshaped_values = np.abs(values).reshape((height, width))
447-
# Mark padding values distinctly by setting them back to NaN
448447
reshaped_values[reshaped_values == np.abs(padding_marker)] = np.nan
449448
else:
450449
reshaped_values = values.reshape((height, width))
451450

452-
_, ax = plt.subplots(figsize=figsize)
451+
fig, ax = plt.subplots(figsize=figsize)
453452
cax = ax.imshow(reshaped_values, cmap=cmap, interpolation="nearest")
454453

454+
# Adjust the position and size of the colorbar
455+
cbar = fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04)
456+
455457
for i in range(height):
456458
for j in range(width):
457459
if np.isnan(reshaped_values[i, j]):
458460
ax.text(j, i, "P", ha="center", va="center", color="red")
459461

460-
plt.colorbar(cax, label="Value")
461462
plt.title(f"{nn_values_names} Plot for {layer}")
462463
if show:
463464
plt.show()
464465

465-
# Add reshaped_values to the dictionary res
466466
res[layer] = reshaped_values
467+
467468
if return_reshaped:
468469
return res
469470

0 commit comments

Comments
 (0)