1818from spotpython .data .lightdatamodule import LightDataModule
1919
2020
21- def check_for_nans (data , layer_index ):
21+ def check_for_nans (data , layer_index ) -> bool :
2222 """Checks for NaN values in the tensor data.
2323
2424 Args:
@@ -34,14 +34,16 @@ def check_for_nans(data, layer_index):
3434 return False
3535
3636
37- def get_activations (net , fun_control , batch_size , device = "cpu" ) -> tuple :
37+ def get_activations (net , fun_control , batch_size , device = "cpu" , normalize = True ) -> tuple :
3838 """Computes the activations for each layer of the network and
3939 the mean activations for each layer. Both are returned as a dictionary.
4040
4141 Args:
4242 net (nn.Module): The neural network model.
4343 fun_control (dict): A dictionary containing the dataset.
4444 device (str): The device to run the model on. Defaults to "cpu".
45+ batch_size (int): The batch size for the data loader.
46+ normalize (bool): Whether to normalize the input data. Defaults to True.
4547
4648 Returns:
4749 tuple: A tuple containing the activations and mean activations for each layer.
@@ -54,12 +56,19 @@ def get_activations(net, fun_control, batch_size, device="cpu") -> tuple:
5456 mean_activations = {}
5557 net .eval () # Set the model to evaluation mode
5658 dataset = fun_control ["data_set" ]
57- dataloader = DataLoader (dataset , batch_size = batch_size , shuffle = False )
58- inputs , _ = next (iter (dataloader ))
59+ data_module = LightDataModule (
60+ dataset = dataset ,
61+ batch_size = batch_size ,
62+ test_size = fun_control ["test_size" ],
63+ scaler = fun_control ["scaler" ],
64+ verbosity = 10 ,
65+ )
66+ data_module .setup (stage = "test" )
67+ test_loader = data_module .test_dataloader ()
68+ inputs , _ = next (iter (test_loader ))
5969 inputs = inputs .to (device )
60-
61- # Normalize input data
62- inputs = (inputs - inputs .mean ()) / inputs .std ()
70+ if normalize :
71+ inputs = (inputs - inputs .mean ()) / inputs .std ()
6372
6473 with torch .no_grad ():
6574 inputs = inputs .view (inputs .size (0 ), - 1 )
@@ -190,7 +199,7 @@ def get_weights(net, return_index=False) -> dict:
190199 return weights
191200
192201
193- def get_gradients (net , fun_control , batch_size , device = "cpu" ) -> dict :
202+ def get_gradients (net , fun_control , batch_size , device = "cpu" , normalize = True ) -> dict :
194203 """
195204 Get the gradients of a neural network.
196205
@@ -203,6 +212,8 @@ def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
203212 The batch size.
204213 device (str, optional):
205214 The device to use. Defaults to "cpu".
215+ normalize (bool, optional):
216+ Whether to normalize the input data. Defaults to True.
206217
207218 Returns:
208219 dict: A dictionary with the gradients of the neural network.
@@ -245,8 +256,18 @@ def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
245256 """
246257 net .eval ()
247258 dataset = fun_control ["data_set" ]
248- dataloader = DataLoader (dataset , batch_size = batch_size , shuffle = False )
249- inputs , targets = next (iter (dataloader ))
259+ data_module = LightDataModule (
260+ dataset = dataset ,
261+ batch_size = batch_size ,
262+ test_size = fun_control ["test_size" ],
263+ scaler = fun_control ["scaler" ],
264+ verbosity = 10 ,
265+ )
266+ data_module .setup (stage = "test" )
267+ test_loader = data_module .test_dataloader ()
268+ inputs , targets = next (iter (test_loader ))
269+ if normalize :
270+ inputs = (inputs - inputs .mean ()) / inputs .std ()
250271 inputs , targets = inputs .to (device ), targets .to (device )
251272 # Pass one batch through the network, and calculate the gradients for the weights
252273 net .zero_grad ()
@@ -396,7 +417,16 @@ def visualize_weights_distributions(net, color="C0", columns=2) -> None:
396417
397418
398419def visualize_gradient_distributions (
399- net , fun_control , batch_size , device = "cpu" , color = "C0" , xlabel = None , stat = "count" , use_kde = True , columns = 2
420+ net ,
421+ fun_control ,
422+ batch_size ,
423+ device = "cpu" ,
424+ color = "C0" ,
425+ xlabel = None ,
426+ stat = "count" ,
427+ use_kde = True ,
428+ columns = 2 ,
429+ normalize = True ,
400430) -> None :
401431 """
402432 Plot the gradients distributions of a neural network.
@@ -420,12 +450,14 @@ def visualize_gradient_distributions(
420450 Whether to use kde. Defaults to True.
421451 columns (int, optional):
422452 The number of columns. Defaults to 2.
453+ normalize (bool, optional):
454+ Whether to normalize the input data. Defaults to True.
423455
424456 Returns:
425457 None
426458
427459 """
428- grads = get_gradients (net , fun_control , batch_size , device )
460+ grads = get_gradients (net , fun_control , batch_size , device , normalize = normalize )
429461 plot_nn_values_hist (grads , net , nn_values_names = "Gradients" , color = color , columns = columns )
430462
431463
@@ -438,8 +470,6 @@ def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", fig
438470 Args:
439471 mean_activations (dict):
440472 A dictionary with the mean activations of the neural network.
441- device (str, optional):
442- The device to use.
443473 absolute (bool, optional):
444474 Whether to use the absolute values. Defaults to True.
445475 cmap (str, optional):
@@ -483,7 +513,9 @@ def visualize_weights(net, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
483513 plot_nn_values_scatter (nn_values = weights , nn_values_names = "Weights" , absolute = absolute , cmap = cmap , figsize = figsize )
484514
485515
486- def visualize_gradients (net , fun_control , batch_size , absolute = True , cmap = "gray" , figsize = (6 , 6 ), device = "cpu" ) -> None :
516+ def visualize_gradients (
517+ net , fun_control , batch_size , absolute = True , cmap = "gray" , figsize = (6 , 6 ), device = "cpu" , normalize = True
518+ ) -> None :
487519 """
488520 Scatter plots the gradients of a neural network.
489521
@@ -502,15 +534,18 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
502534 The figure size. Defaults to (6, 6).
503535 device (str, optional):
504536 The device to use. Defaults to "cpu".
537+ normalize (bool, optional):
538+ Whether to normalize the input data. Defaults to True.
505539
506540 Returns:
507541 None
508542 """
509543 grads = get_gradients (
510- net ,
511- fun_control ,
544+ net = net ,
545+ fun_control = fun_control ,
512546 batch_size = batch_size ,
513547 device = device ,
548+ normalize = normalize ,
514549 )
515550 plot_nn_values_scatter (nn_values = grads , nn_values_names = "Gradients" , absolute = absolute , cmap = cmap , figsize = figsize )
516551
@@ -523,6 +558,7 @@ def get_attributions(
523558 abs_attr = True ,
524559 n_rel = 5 ,
525560 device = "cpu" ,
561+ normalize = True ,
526562) -> pd .DataFrame :
527563 """Get the attributions of a neural network.
528564
@@ -541,6 +577,8 @@ def get_attributions(
541577 The number of relevant features. Defaults to 5.
542578 device (str, optional):
543579 The device to use. Defaults to "cpu".
580+ normalize (bool, optional):
581+ Whether to normalize the input data. Defaults to True.
544582
545583 Returns:
546584 pd.DataFrame (object): A DataFrame with the attributions.
@@ -597,6 +635,8 @@ def get_attributions(
597635 """
598636 )
599637 for inputs , _ in test_loader :
638+ if normalize :
639+ inputs = (inputs - inputs .mean ()) / inputs .std ()
600640 inputs .requires_grad_ ()
601641 attributions = attr .attribute (inputs , return_convergence_delta = False , baselines = baseline )
602642 if total_attributions is None :
@@ -670,7 +710,7 @@ def is_square(n) -> bool:
670710 return n == int (math .sqrt (n )) ** 2
671711
672712
673- def get_layer_conductance (spot_tuner , fun_control , layer_idx , device = "cpu" ) -> np .ndarray :
713+ def get_layer_conductance (spot_tuner , fun_control , layer_idx , device = "cpu" , normalize = True ) -> np .ndarray :
674714 """
675715 Compute the average layer conductance attributions for a specified layer in the model.
676716
@@ -683,6 +723,8 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu") -> n
683723 Index of the layer for which to compute layer conductance attributions.
684724 device (str, optional):
685725 The device to use. Defaults to "cpu".
726+ normalize (bool, optional):
727+ Whether to normalize the input data. Defaults to True.
686728
687729 Returns:
688730 numpy.ndarray:
@@ -710,15 +752,15 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu") -> n
710752 if feature_names is None :
711753 feature_names = [f"x{ i } " for i in range (n_features )]
712754 batch_size = config ["batch_size" ]
713- # train_loader = DataLoader(dataset, batch_size=batch_size)
714755 test_loader = DataLoader (dataset , batch_size = batch_size )
715-
716756 total_layer_attributions = None
717757 layers = model .layers
718758 print ("Conductance analysis for layer: " , layers [layer_idx ])
719759 lc = LayerConductance (model , layers [layer_idx ])
720760
721761 for inputs , labels in test_loader :
762+ if normalize :
763+ inputs = (inputs - inputs .mean ()) / inputs .std ()
722764 lc_attr_test = lc .attribute (inputs , n_steps = 10 , attribute_to_layer_input = True )
723765 if total_layer_attributions is None :
724766 total_layer_attributions = lc_attr_test
@@ -844,8 +886,6 @@ def sort_layers(data_dict) -> dict:
844886 """
845887 # Use a lambda function to extract the number X from "Layer X" and sort based on that number
846888 sorted_items = sorted (data_dict .items (), key = lambda item : int (item [0 ].split ()[1 ]))
847-
848889 # Create a new dictionary from the sorted items
849890 sorted_dict = dict (sorted_items )
850-
851891 return sorted_dict
0 commit comments