Skip to content

Commit 9d6a9f6

Browse files
0.15.35
xai cleanup
1 parent a27fa1a commit 9d6a9f6

2 files changed

Lines changed: 63 additions & 23 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.34"
10+
version = "0.15.35"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from spotpython.data.lightdatamodule import LightDataModule
1919

2020

21-
def check_for_nans(data, layer_index):
21+
def check_for_nans(data, layer_index) -> bool:
2222
"""Checks for NaN values in the tensor data.
2323
2424
Args:
@@ -34,14 +34,16 @@ def check_for_nans(data, layer_index):
3434
return False
3535

3636

37-
def get_activations(net, fun_control, batch_size, device="cpu") -> tuple:
37+
def get_activations(net, fun_control, batch_size, device="cpu", normalize=True) -> tuple:
3838
"""Computes the activations for each layer of the network and
3939
the mean activations for each layer. Both are returned as a dictionary.
4040
4141
Args:
4242
net (nn.Module): The neural network model.
4343
fun_control (dict): A dictionary containing the dataset.
4444
device (str): The device to run the model on. Defaults to "cpu".
45+
batch_size (int): The batch size for the data loader.
46+
normalize (bool): Whether to normalize the input data. Defaults to True.
4547
4648
Returns:
4749
tuple: A tuple containing the activations and mean activations for each layer.
@@ -54,12 +56,19 @@ def get_activations(net, fun_control, batch_size, device="cpu") -> tuple:
5456
mean_activations = {}
5557
net.eval() # Set the model to evaluation mode
5658
dataset = fun_control["data_set"]
57-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
58-
inputs, _ = next(iter(dataloader))
59+
data_module = LightDataModule(
60+
dataset=dataset,
61+
batch_size=batch_size,
62+
test_size=fun_control["test_size"],
63+
scaler=fun_control["scaler"],
64+
verbosity=10,
65+
)
66+
data_module.setup(stage="test")
67+
test_loader = data_module.test_dataloader()
68+
inputs, _ = next(iter(test_loader))
5969
inputs = inputs.to(device)
60-
61-
# Normalize input data
62-
inputs = (inputs - inputs.mean()) / inputs.std()
70+
if normalize:
71+
inputs = (inputs - inputs.mean()) / inputs.std()
6372

6473
with torch.no_grad():
6574
inputs = inputs.view(inputs.size(0), -1)
@@ -190,7 +199,7 @@ def get_weights(net, return_index=False) -> dict:
190199
return weights
191200

192201

193-
def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
202+
def get_gradients(net, fun_control, batch_size, device="cpu", normalize=True) -> dict:
194203
"""
195204
Get the gradients of a neural network.
196205
@@ -203,6 +212,8 @@ def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
203212
The batch size.
204213
device (str, optional):
205214
The device to use. Defaults to "cpu".
215+
normalize (bool, optional):
216+
Whether to normalize the input data. Defaults to True.
206217
207218
Returns:
208219
dict: A dictionary with the gradients of the neural network.
@@ -245,8 +256,18 @@ def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
245256
"""
246257
net.eval()
247258
dataset = fun_control["data_set"]
248-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
249-
inputs, targets = next(iter(dataloader))
259+
data_module = LightDataModule(
260+
dataset=dataset,
261+
batch_size=batch_size,
262+
test_size=fun_control["test_size"],
263+
scaler=fun_control["scaler"],
264+
verbosity=10,
265+
)
266+
data_module.setup(stage="test")
267+
test_loader = data_module.test_dataloader()
268+
inputs, targets = next(iter(test_loader))
269+
if normalize:
270+
inputs = (inputs - inputs.mean()) / inputs.std()
250271
inputs, targets = inputs.to(device), targets.to(device)
251272
# Pass one batch through the network, and calculate the gradients for the weights
252273
net.zero_grad()
@@ -396,7 +417,16 @@ def visualize_weights_distributions(net, color="C0", columns=2) -> None:
396417

397418

398419
def visualize_gradient_distributions(
399-
net, fun_control, batch_size, device="cpu", color="C0", xlabel=None, stat="count", use_kde=True, columns=2
420+
net,
421+
fun_control,
422+
batch_size,
423+
device="cpu",
424+
color="C0",
425+
xlabel=None,
426+
stat="count",
427+
use_kde=True,
428+
columns=2,
429+
normalize=True,
400430
) -> None:
401431
"""
402432
Plot the gradients distributions of a neural network.
@@ -420,12 +450,14 @@ def visualize_gradient_distributions(
420450
Whether to use kde. Defaults to True.
421451
columns (int, optional):
422452
The number of columns. Defaults to 2.
453+
normalize (bool, optional):
454+
Whether to normalize the input data. Defaults to True.
423455
424456
Returns:
425457
None
426458
427459
"""
428-
grads = get_gradients(net, fun_control, batch_size, device)
460+
grads = get_gradients(net, fun_control, batch_size, device, normalize=normalize)
429461
plot_nn_values_hist(grads, net, nn_values_names="Gradients", color=color, columns=columns)
430462

431463

@@ -438,8 +470,6 @@ def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", fig
438470
Args:
439471
mean_activations (dict):
440472
A dictionary with the mean activations of the neural network.
441-
device (str, optional):
442-
The device to use.
443473
absolute (bool, optional):
444474
Whether to use the absolute values. Defaults to True.
445475
cmap (str, optional):
@@ -483,7 +513,9 @@ def visualize_weights(net, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
483513
plot_nn_values_scatter(nn_values=weights, nn_values_names="Weights", absolute=absolute, cmap=cmap, figsize=figsize)
484514

485515

486-
def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray", figsize=(6, 6), device="cpu") -> None:
516+
def visualize_gradients(
517+
net, fun_control, batch_size, absolute=True, cmap="gray", figsize=(6, 6), device="cpu", normalize=True
518+
) -> None:
487519
"""
488520
Scatter plots the gradients of a neural network.
489521
@@ -502,15 +534,18 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
502534
The figure size. Defaults to (6, 6).
503535
device (str, optional):
504536
The device to use. Defaults to "cpu".
537+
normalize (bool, optional):
538+
Whether to normalize the input data. Defaults to True.
505539
506540
Returns:
507541
None
508542
"""
509543
grads = get_gradients(
510-
net,
511-
fun_control,
544+
net=net,
545+
fun_control=fun_control,
512546
batch_size=batch_size,
513547
device=device,
548+
normalize=normalize,
514549
)
515550
plot_nn_values_scatter(nn_values=grads, nn_values_names="Gradients", absolute=absolute, cmap=cmap, figsize=figsize)
516551

@@ -523,6 +558,7 @@ def get_attributions(
523558
abs_attr=True,
524559
n_rel=5,
525560
device="cpu",
561+
normalize=True,
526562
) -> pd.DataFrame:
527563
"""Get the attributions of a neural network.
528564
@@ -541,6 +577,8 @@ def get_attributions(
541577
The number of relevant features. Defaults to 5.
542578
device (str, optional):
543579
The device to use. Defaults to "cpu".
580+
normalize (bool, optional):
581+
Whether to normalize the input data. Defaults to True.
544582
545583
Returns:
546584
pd.DataFrame (object): A DataFrame with the attributions.
@@ -597,6 +635,8 @@ def get_attributions(
597635
"""
598636
)
599637
for inputs, _ in test_loader:
638+
if normalize:
639+
inputs = (inputs - inputs.mean()) / inputs.std()
600640
inputs.requires_grad_()
601641
attributions = attr.attribute(inputs, return_convergence_delta=False, baselines=baseline)
602642
if total_attributions is None:
@@ -670,7 +710,7 @@ def is_square(n) -> bool:
670710
return n == int(math.sqrt(n)) ** 2
671711

672712

673-
def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu") -> np.ndarray:
713+
def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", normalize=True) -> np.ndarray:
674714
"""
675715
Compute the average layer conductance attributions for a specified layer in the model.
676716
@@ -683,6 +723,8 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu") -> n
683723
Index of the layer for which to compute layer conductance attributions.
684724
device (str, optional):
685725
The device to use. Defaults to "cpu".
726+
normalize (bool, optional):
727+
Whether to normalize the input data. Defaults to True.
686728
687729
Returns:
688730
numpy.ndarray:
@@ -710,15 +752,15 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu") -> n
710752
if feature_names is None:
711753
feature_names = [f"x{i}" for i in range(n_features)]
712754
batch_size = config["batch_size"]
713-
# train_loader = DataLoader(dataset, batch_size=batch_size)
714755
test_loader = DataLoader(dataset, batch_size=batch_size)
715-
716756
total_layer_attributions = None
717757
layers = model.layers
718758
print("Conductance analysis for layer: ", layers[layer_idx])
719759
lc = LayerConductance(model, layers[layer_idx])
720760

721761
for inputs, labels in test_loader:
762+
if normalize:
763+
inputs = (inputs - inputs.mean()) / inputs.std()
722764
lc_attr_test = lc.attribute(inputs, n_steps=10, attribute_to_layer_input=True)
723765
if total_layer_attributions is None:
724766
total_layer_attributions = lc_attr_test
@@ -844,8 +886,6 @@ def sort_layers(data_dict) -> dict:
844886
"""
845887
# Use a lambda function to extract the number X from "Layer X" and sort based on that number
846888
sorted_items = sorted(data_dict.items(), key=lambda item: int(item[0].split()[1]))
847-
848889
# Create a new dictionary from the sorted items
849890
sorted_dict = dict(sorted_items)
850-
851891
return sorted_dict

0 commit comments

Comments
 (0)