33from numpy import array
44from sklearn .pipeline import make_pipeline
55from spotPython .utils .convert import get_Xy_from_df
6+ from spotPython .utils .data import load_data
7+ import torch .nn as nn
8+ import torch .optim as optim
9+ import torch
10+ import os
11+ from torch .utils .data import random_split
612
713
814from spotPython .hyperparameters .values import assign_values
@@ -64,23 +70,100 @@ def check_X_shape(self, X):
6470 raise Exception
6571
6672 def evaluate_model (self , model , fun_control ):
73+ # TODO: config anpassen
6774 try :
75+ lr = fun_control ["lr" ]
76+ checkpoint_dir = fun_control ["checkpoint_dir" ]
77+ data_dir = fun_control ["data_dir" ]
78+
6879 X_train , y_train = get_Xy_from_df (fun_control ["train" ], fun_control ["target_column" ])
6980 X_test , y_test = get_Xy_from_df (fun_control ["test" ], fun_control ["target_column" ])
7081 model .fit (X_train , y_train )
71-
72-
73-
74-
75-
7682 df_preds = model .predict (X_test )
7783 df_eval = fun_control ["metric_sklearn" ](y_test , df_preds )
84+ #
85+ device = "cpu"
86+ # if torch.cuda.is_available():
87+ # device = "cuda:0"
88+ # if torch.cuda.device_count() > 1:
89+ # net = nn.DataParallel(net)
90+ model .to (device )
91+
92+ criterion = nn .CrossEntropyLoss ()
93+ optimizer = optim .SGD (model .parameters (), lr = lr , momentum = 0.9 )
94+
95+ if checkpoint_dir :
96+ model_state , optimizer_state = torch .load (os .path .join (checkpoint_dir , "checkpoint" ))
97+ model .load_state_dict (model_state )
98+ optimizer .load_state_dict (optimizer_state )
99+
100+ trainset , testset = load_data (data_dir )
101+
102+ test_abs = int (len (trainset ) * 0.8 )
103+ train_subset , val_subset = random_split (trainset , [test_abs , len (trainset ) - test_abs ])
104+
105+ trainloader = torch .utils .data .DataLoader (
106+ train_subset , batch_size = int (config ["batch_size" ]), shuffle = True , num_workers = 8
107+ )
108+ valloader = torch .utils .data .DataLoader (
109+ val_subset , batch_size = int (config ["batch_size" ]), shuffle = True , num_workers = 8
110+ )
111+
112+ for epoch in range (10 ): # loop over the dataset multiple times
113+ running_loss = 0.0
114+ epoch_steps = 0
115+ for i , data in enumerate (trainloader , 0 ):
116+ # get the inputs; data is a list of [inputs, labels]
117+ inputs , labels = data
118+ inputs , labels = inputs .to (device ), labels .to (device )
119+
120+ # zero the parameter gradients
121+ optimizer .zero_grad ()
122+
123+ # forward + backward + optimize
124+ outputs = model (inputs )
125+ loss = criterion (outputs , labels )
126+ loss .backward ()
127+ optimizer .step ()
128+
129+ # print statistics
130+ running_loss += loss .item ()
131+ epoch_steps += 1
132+ if i % 2000 == 1999 : # print every 2000 mini-batches
133+ print ("[%d, %5d] loss: %.3f" % (epoch + 1 , i + 1 , running_loss / epoch_steps ))
134+ running_loss = 0.0
135+
136+ # Validation loss
137+ val_loss = 0.0
138+ val_steps = 0
139+ total = 0
140+ correct = 0
141+ for i , data in enumerate (valloader , 0 ):
142+ with torch .no_grad ():
143+ inputs , labels = data
144+ inputs , labels = inputs .to (device ), labels .to (device )
145+
146+ outputs = model (inputs )
147+ _ , predicted = torch .max (outputs .data , 1 )
148+ total += labels .size (0 )
149+ correct += (predicted == labels ).sum ().item ()
150+
151+ loss = criterion (outputs , labels )
152+ val_loss += loss .cpu ().numpy ()
153+ val_steps += 1
154+
155+ # TODO:
156+ # with tune.checkpoint_dir(epoch) as checkpoint_dir:
157+ path = os .path .join (checkpoint_dir , "checkpoint" )
158+ torch .save ((model .state_dict (), optimizer .state_dict ()), path )
159+ df_eval = val_loss / val_steps
160+ df_preds = np .nan
161+ # accuracy = correct / total
78162 except Exception as err :
79163 print (f"Error in fun_sklearn(). Call to evaluate_model failed. { err = } , { type (err )= } " )
80164 df_eval = np .nan
81- df_eval = np .nan
165+ df_preds = np .nan
82166 return df_eval , df_preds
83-
84167
85168 def get_sklearn_df_eval_preds (self , model ):
86169 try :
@@ -92,7 +175,7 @@ def get_sklearn_df_eval_preds(self, model):
92175 df_preds = np .nan
93176 return df_eval , df_preds
94177
95- def fun_sklearn (self , X , fun_control = None ):
178+ def fun_torch (self , X , fun_control = None ):
96179 z_res = np .array ([], dtype = float )
97180 self .fun_control .update (fun_control )
98181 self .check_X_shape (X )
0 commit comments