1818from spotpython .data .lightdatamodule import LightDataModule
1919
2020
def check_for_nans(data, layer_index) -> bool:
    """Checks for NaN values in the tensor data.

    If at least one NaN is present, a message naming the layer is printed.

    Args:
        data (torch.Tensor): The tensor to check for NaN values.
        layer_index (int): The index of the layer. Only used in the printed message.

    Returns:
        bool: True if NaNs are found, False otherwise.

    Examples:
        >>> import torch
            check_for_nans(torch.tensor([1.0, float("nan")]), layer_index=0)
            NaN detected after layer 0
            True
    """
    if torch.isnan(data).any():
        print(f"NaN detected after layer {layer_index}")
        return True
    return False
65-
66-
def get_activations(net, fun_control, batch_size, device="cpu") -> tuple:
    """Computes the activations for each linear layer of the network and
    the mean activations for each linear layer. Both are returned as a dictionary.

    Only the first batch of the data set is used. The batch is standardized with its
    overall mean and standard deviation, flattened to (batch, features) and passed
    through `net.layers[:-1]` one layer at a time. If NaNs show up after a layer, the
    loop stops and the values collected so far are returned.

    Args:
        net (nn.Module): The neural network model. It must provide a `layers` attribute.
        fun_control (dict): A dictionary containing the dataset under the key "data_set".
        batch_size (int): The batch size used to draw the batch from the dataset.
        device (str): The device to run the model on. Defaults to "cpu".

    Returns:
        tuple: A tuple `(activations, mean_activations)`. Both are dictionaries keyed by
            the index of the linear layer in `net.layers`. `activations` holds the flattened
            outputs of the layer, `mean_activations` holds the outputs averaged over the batch.

    Examples:
        >>> from spotpython.plot.xai import get_activations
            activations, mean_activations = get_activations(net, fun_control, batch_size=32, device="cpu")
    """
    activations = {}
    mean_activations = {}
    net.eval()  # Set the model to evaluation mode
    dataset = fun_control["data_set"]
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    inputs, _ = next(iter(dataloader))
    inputs = inputs.to(device)

    # Normalize input data. The standard deviation is 0 for a constant batch and NaN for a
    # batch that holds a single value; dividing by it would turn every input into NaN/inf,
    # so in these cases the data is only centered.
    std = inputs.std()
    if torch.isfinite(std) and std > 0:
        inputs = (inputs - inputs.mean()) / std
    else:
        inputs = inputs - inputs.mean()

    with torch.no_grad():
        inputs = inputs.view(inputs.size(0), -1)
        # Loop through all layers except the last entry of net.layers
        # (presumably the output layer -- confirm against the model definition)
        for layer_index, layer in enumerate(net.layers[:-1]):
            inputs = layer(inputs)  # Forward pass through the layer

            # Stop at the first layer that produces NaNs; later layers would only propagate them
            if check_for_nans(inputs, layer_index):
                break

            # Collect activations for Linear layers
            if isinstance(layer, nn.Linear):
                activations[layer_index] = inputs.view(-1).cpu().numpy()
                mean_activations[layer_index] = inputs.mean(dim=0).cpu().numpy()
    return activations, mean_activations
79+
80+
def visualize_activations_distributions(activations, net, color="C0", columns=4, bins=50, show=True) -> None:
    """Plots the distribution of activations for each layer
    that were determined via the get_activations function.

    One histogram (with kernel density estimate) is drawn per entry of `activations`.
    The subplots are arranged in a grid with `columns` columns.

    Args:
        activations (dict): A dictionary containing activations for each layer.
            The keys are used to index `net.layers` for the subplot titles.
        net (nn.Module): The neural network model.
        color (str): The color for the plot histogram. Defaults to "C0".
        columns (int): The number of columns for the subplots. Defaults to 4.
        bins (int): The number of bins for the histogram. Defaults to 50.
        show (bool): Whether to show the plot. Defaults to True.

    Returns:
        None
    """
    if not activations:
        # Nothing to plot; a grid with zero rows cannot be created.
        return
    rows = math.ceil(len(activations) / columns)
    # squeeze=False keeps `ax` two-dimensional, so the [row][column] indexing below
    # also works for a single row or a single column of subplots.
    fig, ax = plt.subplots(rows, columns, figsize=(columns * 2.7, rows * 2.5), squeeze=False)
    for fig_index, key in enumerate(activations):
        key_ax = ax[fig_index // columns][fig_index % columns]
        sns.histplot(data=activations[key], bins=bins, ax=key_ax, color=color, kde=True, stat="density")
        key_ax.set_title(f"Layer {key} - {net.layers[key].__class__.__name__}")
    fig.suptitle("Activation distribution", fontsize=14)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    if show:
        plt.show()
    # NOTE(review): assumed to run regardless of `show` -- confirm against the repository.
    plt.close()
134109
135110
136111def get_weights (net , return_index = False ) -> dict :
@@ -209,7 +184,6 @@ def get_weights(net, return_index=False) -> dict:
209184 index .append (int (name .split ("." )[1 ]))
210185 key_name = f"Layer { name .split ('.' )[1 ]} "
211186 weights [key_name ] = param .detach ().view (- 1 ).cpu ().numpy ()
212- # print(f"weights: {weights}")
213187 if return_index :
214188 return weights , index
215189 else :
@@ -269,19 +243,15 @@ def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
269243 0.02890352, 0.0114617 , 0.08183316, 0.2495192 , 0.5108763 ,
270244 0.14668094, -0.07902834, 0.00912531, 0.02640062, 0.14108546, ...
271245 """
272- grads = {}
273246 net .eval ()
274247 dataset = fun_control ["data_set" ]
275- # Create DataLoader
276248 dataloader = DataLoader (dataset , batch_size = batch_size , shuffle = False )
277- # for batch in dataloader:
278- # inputs, targets = batch
279- # small_loader = data.DataLoader(train_set, batch_size=1024)
280249 inputs , targets = next (iter (dataloader ))
281250 inputs , targets = inputs .to (device ), targets .to (device )
282251 # Pass one batch through the network, and calculate the gradients for the weights
283252 net .zero_grad ()
284253 preds = net (inputs )
254+ preds = preds .squeeze (- 1 ) # Remove the last dimension if it's 1
285255 # TODO: Add more loss functions
286256 loss = F .mse_loss (preds , targets )
287257 # loss = F.cross_entropy(preds, labels) # Same as nn.CrossEntropyLoss, but as a function instead of module
@@ -332,67 +302,6 @@ def plot_nn_values_hist(nn_values, net, nn_values_names="", color="C0", columns=
332302 plt .show ()
333303
334304
335- def old_plot_nn_values_scatter (
336- nn_values , nn_values_names = "" , absolute = True , cmap = "gray" , figsize = (6 , 6 ), return_reshaped = False
337- ):
338- """
339- Plot the values of a neural network.
340- Can be used to plot the weights, gradients, or activations of a neural network.
341-
342- Args:
343- nn_values (dict):
344- A dictionary with the values of the neural network. For example,
345- the weights, gradients, or activations.
346- nn_values_names (str, optional):
347- The name of the values. Defaults to "".
348- absolute (bool, optional):
349- Whether to use the absolute values. Defaults to True.
350- cmap (str, optional):
351- The colormap to use. Defaults to "gray".
352- figsize (tuple, optional):
353- The figure size. Defaults to (6, 6).
354- return_reshaped (bool, optional):
355- Whether to return the reshaped values. Defaults to False.
356-
357- """
358- if cmap == "gray" :
359- cmap = "gray"
360- elif cmap == "BlueWhiteRed" :
361- cmap = colors .LinearSegmentedColormap .from_list ("" , ["blue" , "white" , "red" ])
362- elif cmap == "GreenYellowRed" :
363- cmap = colors .LinearSegmentedColormap .from_list ("" , ["green" , "yellow" , "red" ])
364- else :
365- cmap = "viridis"
366-
367- res = {}
368- for layer , values in nn_values .items ():
369- k = len (values )
370- print (f"{ k } values in Layer { layer } ." )
371- if is_square (k ):
372- n = int (math .sqrt (k ))
373- else :
374- n = int (math .sqrt (len (values )) + 1 )
375- padding = np .zeros (n * n - len (values )) # create a zero array for padding
376- print (f"{ len (padding )} padding values added." )
377- values = np .concatenate ((values , padding )) # append the padding to the values
378-
379- print (f"{ len (values )} values in Layer { layer } ." )
380- if absolute :
381- reshaped_values = np .abs (values .reshape ((n , n )))
382- else :
383- reshaped_values = values .reshape ((n , n ))
384-
385- plt .figure (figsize = figsize )
386- plt .imshow (reshaped_values , cmap = cmap ) # use colormap to indicate the values
387- plt .colorbar (label = "Value" )
388- plt .title (f"{ nn_values_names } Plot for { layer } " )
389- plt .show ()
390- # add reshaped_values to the dictionary res
391- res [layer ] = reshaped_values
392- if return_reshaped :
393- return res
394-
395-
396305def plot_nn_values_scatter (
397306 nn_values , nn_values_names = "" , absolute = True , cmap = "gray" , figsize = (6 , 6 ), return_reshaped = False , show = True
398307) -> dict :
@@ -466,32 +375,6 @@ def plot_nn_values_scatter(
466375 return res
467376
468377
469- def visualize_activations_distributions (net , fun_control , batch_size , device = "cpu" , color = "C0" , columns = 2 ) -> None :
470- """
471- Plots a histogram of the activations of a neural network.
472-
473- Args:
474- net (object):
475- A neural network.
476- fun_control (dict):
477- A dictionary with the function control.
478- batch_size (int, optional):
479- The batch size.
480- device (str, optional):
481- The device to use. Defaults to "cpu".
482- color (str, optional):
483- The color to use. Defaults to "C0".
484- columns (int, optional):
485- The number of columns. Defaults to 2.
486-
487- Returns:
488- None
489-
490- """
491- activations = get_activations_full (net , fun_control , batch_size , device )
492- plot_nn_values_hist (activations , net , nn_values_names = "Activations" , color = color , columns = columns )
493-
494-
495378def visualize_weights_distributions (net , color = "C0" , columns = 2 ) -> None :
496379 """
497380 Plot the weights distributions of a neural network.
@@ -546,17 +429,15 @@ def visualize_gradient_distributions(
546429 plot_nn_values_hist (grads , net , nn_values_names = "Gradients" , color = color , columns = columns )
547430
548431
def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
    """
    Scatter plots the mean activations of a neural network for each layer.
    mean_activations is a dictionary with the mean activations of the neural network computed via
    the get_activations function.

    Args:
        mean_activations (dict):
            A dictionary with the mean activations of the neural network.
        absolute (bool, optional):
            Passed on to plot_nn_values_scatter. Defaults to True.
        cmap (str, optional):
            Passed on to plot_nn_values_scatter. Defaults to "gray".
        figsize (tuple, optional):
            Passed on to plot_nn_values_scatter. Defaults to (6, 6).

    Returns:
        None

    Examples:
        >>> from spotpython.plot.xai import get_activations, visualize_mean_activations
            activations, mean_activations = get_activations(net, fun_control, batch_size, device)
            visualize_mean_activations(mean_activations)

    """
    plot_nn_values_scatter(
        nn_values=mean_activations, nn_values_names="Average Activations", absolute=absolute, cmap=cmap, figsize=figsize
    )
577462
578463
0 commit comments