Skip to content

Commit a27fa1a

Browse files
xai with full data
1 parent 1bcea70 commit a27fa1a

2 files changed

Lines changed: 81 additions & 196 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.32"
10+
version = "0.15.34"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
@@ -30,7 +30,7 @@ dependencies = [
3030
"graphviz",
3131
"matplotlib",
3232
"mkdocs>=1.6.0",
33-
"mkdocs-material>=9.5.32",
33+
"mkdocs-material>=9.5.33",
3434
"mkdocstrings-python>=1.10.8",
3535
"mkdocs-exclude>=1.0.2",
3636
"mkdocs-gen-files>=0.5.0",

src/spotpython/plot/xai.py

Lines changed: 79 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -18,119 +18,94 @@
1818
from spotpython.data.lightdatamodule import LightDataModule
1919

2020

21-
def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
22-
"""
23-
Get the average activations of each neuron in the neural network's linear layers.
21+
def check_for_nans(data, layer_index):
22+
"""Checks for NaN values in the tensor data.
2423
2524
Args:
26-
net (object): A neural network.
27-
fun_control (dict): A dictionary with the function control.
28-
batch_size (int, optional): The batch size.
29-
device (str, optional): The device to use. Defaults to "cpu".
25+
data (torch.Tensor): The tensor to check for NaN values.
26+
layer_index (int): The index of the layer for logging purposes.
3027
3128
Returns:
32-
dict: A dictionary with the average activations of the neurons in the network.
29+
bool: True if NaNs are found, False otherwise.
3330
"""
34-
activations = {}
35-
net.eval()
36-
37-
dataset = fun_control["data_set"]
38-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
39-
40-
layer_sums = {}
41-
layer_counts = {}
42-
43-
with torch.no_grad():
44-
for inputs, _ in dataloader:
45-
inputs = inputs.to(device)
46-
47-
for layer_index, layer in enumerate(net.layers): # Iterate over layers
48-
inputs = layer(inputs)
31+
if torch.isnan(data).any():
32+
print(f"NaN detected after layer {layer_index}")
33+
return True
34+
return False
4935

50-
if isinstance(layer, nn.Linear):
51-
n_neurons = inputs.size(1)
52-
if layer_index not in layer_sums:
53-
layer_sums[layer_index] = torch.zeros(n_neurons, device=device)
54-
layer_counts[layer_index] = 0
5536

56-
# Sum activation for each neuron
57-
layer_sums[layer_index] += inputs.sum(dim=0)
58-
layer_counts[layer_index] += inputs.size(0) # Add up batch size
59-
60-
# Compute average activations
61-
for layer_index, total_sum in layer_sums.items():
62-
activations[layer_index] = (total_sum / layer_counts[layer_index]).cpu().numpy()
63-
64-
return activations
65-
66-
67-
def get_activations_full(net, fun_control, batch_size, device="cpu") -> dict:
68-
"""
69-
Get the activations of a neural network.
37+
def get_activations(net, fun_control, batch_size, device="cpu") -> tuple:
38+
"""Computes the activations for each layer of the network and
39+
the mean activations for each layer. Both are returned as dictionaries keyed by layer index.
7040
7141
Args:
72-
net (object):
73-
A neural network.
74-
fun_control (dict):
75-
A dictionary with the function control.
76-
batch_size (int, optional):
77-
The batch size.
78-
device (str, optional):
79-
The device to use. Defaults to "cpu".
42+
net (nn.Module): The neural network model.
43+
fun_control (dict): A dictionary containing the dataset.
44+
device (str): The device to run the model on. Defaults to "cpu".
8045
8146
Returns:
82-
dict: A dictionary with the activations of the neural network.
47+
tuple: A tuple containing the activations and mean activations for each layer.
8348
8449
Examples:
85-
>>> from torch.utils.data import DataLoader
86-
from spotpython.utils.init import fun_control_init
87-
from spotpython.hyperparameters.values import set_control_key_value
88-
from spotpython.data.diabetes import Diabetes
89-
from spotpython.light.regression.netlightregression import NetLightRegression
90-
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
91-
from spotpython.hyperparameters.values import add_core_model_to_fun_control
92-
from spotpython.hyperparameters.values import (
93-
get_default_hyperparameters_as_array, get_one_config_from_X)
94-
from spotpython.hyperparameters.values import set_control_key_value
95-
from spotpython.plot.xai import get_activations
96-
fun_control = fun_control_init(
97-
_L_in=10, # 10: diabetes
98-
_L_out=1,
99-
)
100-
dataset = Diabetes()
101-
set_control_key_value(control_dict=fun_control,
102-
key="data_set",
103-
value=dataset,
104-
replace=True)
105-
add_core_model_to_fun_control(fun_control=fun_control,
106-
core_model=NetLightRegression,
107-
hyper_dict=LightHyperDict)
108-
X = get_default_hyperparameters_as_array(fun_control)
109-
config = get_one_config_from_X(X, fun_control)
110-
_L_in = fun_control["_L_in"]
111-
_L_out = fun_control["_L_out"]
112-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
113-
batch_size= config["batch_size"]
114-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
115-
get_activations(model, fun_control=fun_control, batch_size=batch_size, device = "cpu")
116-
{0: array([ 1.43207282e-01, 6.29711570e-03, 1.04200505e-01, -3.79187055e-03,
117-
-1.74976081e-01, -7.97475874e-02, -2.00860098e-01, 2.48444706e-01, ...
118-
50+
>>> from spotpython.plot.xai import get_activations
51+
activations, mean_activations = get_activations(net, fun_control, batch_size)
11952
"""
12053
activations = {}
121-
net.eval()
54+
mean_activations = {}
55+
net.eval() # Set the model to evaluation mode
12256
dataset = fun_control["data_set"]
12357
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
12458
inputs, _ = next(iter(dataloader))
59+
inputs = inputs.to(device)
60+
61+
# Normalize input data
62+
inputs = (inputs - inputs.mean()) / inputs.std()
63+
12564
with torch.no_grad():
126-
layer_index = 0
127-
inputs = inputs.to(device)
12865
inputs = inputs.view(inputs.size(0), -1)
129-
for layer_index, layer in enumerate(net.layers):
130-
inputs = layer(inputs)
66+
# Loop through all layers
67+
for layer_index, layer in enumerate(net.layers[:-1]):
68+
inputs = layer(inputs) # Forward pass through the layer
69+
70+
# Check for NaNs
71+
if check_for_nans(inputs, layer_index):
72+
break
73+
74+
# Collect activations for Linear layers
13175
if isinstance(layer, nn.Linear):
13276
activations[layer_index] = inputs.view(-1).cpu().numpy()
133-
return activations
77+
mean_activations[layer_index] = inputs.mean(dim=0).cpu().numpy()
78+
return activations, mean_activations
79+
80+
81+
def visualize_activations_distributions(activations, net, color="C0", columns=4, bins=50, show=True) -> None:
82+
"""Plots the distribution of activations for each layer
83+
that were determined via the get_activations function.
84+
85+
Args:
86+
activations (dict): A dictionary containing activations for each layer.
87+
net (nn.Module): The neural network model.
88+
color (str): The color for the plot histogram. Defaults to "C0".
89+
columns (int): The number of columns for the subplots. Defaults to 4.
90+
bins (int): The number of bins for the histogram. Defaults to 50.
91+
show (bool): Whether to show the plot. Defaults to True.
92+
93+
Returns:
94+
None
95+
"""
96+
rows = math.ceil(len(activations) / columns)
97+
fig, ax = plt.subplots(rows, columns, figsize=(columns * 2.7, rows * 2.5))
98+
fig_index = 0
99+
for key in activations:
100+
key_ax = ax[fig_index // columns][fig_index % columns]
101+
sns.histplot(data=activations[key], bins=bins, ax=key_ax, color=color, kde=True, stat="density")
102+
key_ax.set_title(f"Layer {key} - {net.layers[key].__class__.__name__}")
103+
fig_index += 1
104+
fig.suptitle("Activation distribution", fontsize=14)
105+
fig.subplots_adjust(hspace=0.4, wspace=0.4)
106+
if show:
107+
plt.show()
108+
plt.close()
134109

135110

136111
def get_weights(net, return_index=False) -> dict:
@@ -209,7 +184,6 @@ def get_weights(net, return_index=False) -> dict:
209184
index.append(int(name.split(".")[1]))
210185
key_name = f"Layer {name.split('.')[1]}"
211186
weights[key_name] = param.detach().view(-1).cpu().numpy()
212-
# print(f"weights: {weights}")
213187
if return_index:
214188
return weights, index
215189
else:
@@ -269,19 +243,15 @@ def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
269243
0.02890352, 0.0114617 , 0.08183316, 0.2495192 , 0.5108763 ,
270244
0.14668094, -0.07902834, 0.00912531, 0.02640062, 0.14108546, ...
271245
"""
272-
grads = {}
273246
net.eval()
274247
dataset = fun_control["data_set"]
275-
# Create DataLoader
276248
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
277-
# for batch in dataloader:
278-
# inputs, targets = batch
279-
# small_loader = data.DataLoader(train_set, batch_size=1024)
280249
inputs, targets = next(iter(dataloader))
281250
inputs, targets = inputs.to(device), targets.to(device)
282251
# Pass one batch through the network, and calculate the gradients for the weights
283252
net.zero_grad()
284253
preds = net(inputs)
254+
preds = preds.squeeze(-1) # Remove the last dimension if it's 1
285255
# TODO: Add more loss functions
286256
loss = F.mse_loss(preds, targets)
287257
# loss = F.cross_entropy(preds, labels) # Same as nn.CrossEntropyLoss, but as a function instead of module
@@ -332,67 +302,6 @@ def plot_nn_values_hist(nn_values, net, nn_values_names="", color="C0", columns=
332302
plt.show()
333303

334304

335-
def old_plot_nn_values_scatter(
336-
nn_values, nn_values_names="", absolute=True, cmap="gray", figsize=(6, 6), return_reshaped=False
337-
):
338-
"""
339-
Plot the values of a neural network.
340-
Can be used to plot the weights, gradients, or activations of a neural network.
341-
342-
Args:
343-
nn_values (dict):
344-
A dictionary with the values of the neural network. For example,
345-
the weights, gradients, or activations.
346-
nn_values_names (str, optional):
347-
The name of the values. Defaults to "".
348-
absolute (bool, optional):
349-
Whether to use the absolute values. Defaults to True.
350-
cmap (str, optional):
351-
The colormap to use. Defaults to "gray".
352-
figsize (tuple, optional):
353-
The figure size. Defaults to (6, 6).
354-
return_reshaped (bool, optional):
355-
Whether to return the reshaped values. Defaults to False.
356-
357-
"""
358-
if cmap == "gray":
359-
cmap = "gray"
360-
elif cmap == "BlueWhiteRed":
361-
cmap = colors.LinearSegmentedColormap.from_list("", ["blue", "white", "red"])
362-
elif cmap == "GreenYellowRed":
363-
cmap = colors.LinearSegmentedColormap.from_list("", ["green", "yellow", "red"])
364-
else:
365-
cmap = "viridis"
366-
367-
res = {}
368-
for layer, values in nn_values.items():
369-
k = len(values)
370-
print(f"{k} values in Layer {layer}.")
371-
if is_square(k):
372-
n = int(math.sqrt(k))
373-
else:
374-
n = int(math.sqrt(len(values)) + 1)
375-
padding = np.zeros(n * n - len(values)) # create a zero array for padding
376-
print(f"{len(padding)} padding values added.")
377-
values = np.concatenate((values, padding)) # append the padding to the values
378-
379-
print(f"{len(values)} values in Layer {layer}.")
380-
if absolute:
381-
reshaped_values = np.abs(values.reshape((n, n)))
382-
else:
383-
reshaped_values = values.reshape((n, n))
384-
385-
plt.figure(figsize=figsize)
386-
plt.imshow(reshaped_values, cmap=cmap) # use colormap to indicate the values
387-
plt.colorbar(label="Value")
388-
plt.title(f"{nn_values_names} Plot for {layer}")
389-
plt.show()
390-
# add reshaped_values to the dictionary res
391-
res[layer] = reshaped_values
392-
if return_reshaped:
393-
return res
394-
395-
396305
def plot_nn_values_scatter(
397306
nn_values, nn_values_names="", absolute=True, cmap="gray", figsize=(6, 6), return_reshaped=False, show=True
398307
) -> dict:
@@ -466,32 +375,6 @@ def plot_nn_values_scatter(
466375
return res
467376

468377

469-
def visualize_activations_distributions(net, fun_control, batch_size, device="cpu", color="C0", columns=2) -> None:
470-
"""
471-
Plots a histogram of the activations of a neural network.
472-
473-
Args:
474-
net (object):
475-
A neural network.
476-
fun_control (dict):
477-
A dictionary with the function control.
478-
batch_size (int, optional):
479-
The batch size.
480-
device (str, optional):
481-
The device to use. Defaults to "cpu".
482-
color (str, optional):
483-
The color to use. Defaults to "C0".
484-
columns (int, optional):
485-
The number of columns. Defaults to 2.
486-
487-
Returns:
488-
None
489-
490-
"""
491-
activations = get_activations_full(net, fun_control, batch_size, device)
492-
plot_nn_values_hist(activations, net, nn_values_names="Activations", color=color, columns=columns)
493-
494-
495378
def visualize_weights_distributions(net, color="C0", columns=2) -> None:
496379
"""
497380
Plot the weights distributions of a neural network.
@@ -546,17 +429,15 @@ def visualize_gradient_distributions(
546429
plot_nn_values_hist(grads, net, nn_values_names="Gradients", color=color, columns=columns)
547430

548431

549-
def visualize_activations(net, fun_control, batch_size, device, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
432+
def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
550433
"""
551-
Scatter plots the activations of a neural network.
434+
Scatter plots the mean activations of a neural network for each layer.
435+
mean_activations is a dictionary with the mean activations of the neural network computed via
436+
the get_activations function.
552437
553438
Args:
554-
net (object):
555-
A neural network.
556-
fun_control (dict):
557-
A dictionary with the function control.
558-
batch_size (int, optional):
559-
The batch size.
439+
mean_activations (dict):
440+
A dictionary with the mean activations of the neural network.
560441
device (str, optional):
561442
The device to use.
562443
absolute (bool, optional):
@@ -569,10 +450,14 @@ def visualize_activations(net, fun_control, batch_size, device, absolute=True, c
569450
Returns:
570451
None
571452
453+
Examples:
454+
>>> from spotpython.plot.xai import get_activations
455+
activations, mean_activations = get_activations(net, fun_control, batch_size)
456+
visualize_mean_activations(mean_activations)
457+
572458
"""
573-
activations = get_activations(net, fun_control, batch_size, device)
574459
plot_nn_values_scatter(
575-
nn_values=activations, nn_values_names="Average Activations", absolute=absolute, cmap=cmap, figsize=figsize
460+
nn_values=mean_activations, nn_values_names="Average Activations", absolute=absolute, cmap=cmap, figsize=figsize
576461
)
577462

578463

0 commit comments

Comments
 (0)