-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy patheval.py
More file actions
72 lines (56 loc) · 2.42 KB
/
eval.py
File metadata and controls
72 lines (56 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
############################################
# Nicola Altini (2020)
#
# This script evaluates the trained CNN on the training and validation sets.
# Train the CNN first (see train.py).
############################################
import os

import torch
import torchvision
from sklearn.metrics import confusion_matrix, classification_report
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import *
from net import Net
from utils import make_pred_on_dataloader, get_classes, subsample_dataset
#%% Create Train and Validation Dataloaders
# Both splits are loaded with the evaluation transform (transform_test) and
# iterated in a fixed order (shuffle=False): this script only runs inference.
print("Creating training dataset from ", train_folder)
train_dataset = torchvision.datasets.ImageFolder(root=train_folder, transform=transform_test)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
val_dataset = torchvision.datasets.ImageFolder(root=val_folder, transform=transform_test)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# Parallel lists of class names and their integer indices, both in
# classes_dict iteration order; they are used below to fix the label order
# of the metrics.
# NOTE(review): assumes get_classes() assigns the same indices that
# ImageFolder derives from the sorted sub-folder names — confirm.
classes, classes_dict = get_classes()
target_names = list(classes_dict)
labels_idxs = [classes_dict[name] for name in target_names]
#%% Load the net and use eval mode
logs_dir = './logs'
net = Net(in_channels=CHANNELS, out_features=NUM_CLASSES)
PATH = os.path.join(logs_dir, 'dog_vs_cat.pth')
# Choose the target device *before* loading the checkpoint: map_location lets
# weights saved from a GPU run be deserialized on a CPU-only machine (without
# it torch.load tries to restore the tensors onto CUDA and fails there).
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
net.load_state_dict(torch.load(PATH, map_location=device))
# Move the net to the selected device and switch it to inference mode.
net = net.to(device)
net = net.eval()
#%% Make prediction on train set
# Inference only: disable autograd bookkeeping (harmless if
# make_pred_on_dataloader already does this internally).
with torch.no_grad():
    y_true_train, y_pred_train = make_pred_on_dataloader(net, train_dataloader)
#%% Compute metrics on train set
# labels=labels_idxs is passed to classification_report as well as to
# confusion_matrix. Without it, sklearn pairs target_names with the *sorted*
# labels present in y_true/y_pred, which mislabels classes whenever
# classes_dict is not in index order, and raises in recent scikit-learn
# versions when some class never appears in the split.
# Confusion-matrix rows are true labels, columns predicted labels, both in
# labels_idxs order.
cf_train = confusion_matrix(y_true_train, y_pred_train, labels=labels_idxs)
cr_train = classification_report(y_true_train, y_pred_train, labels=labels_idxs,
                                 target_names=target_names, output_dict=True)
print(classification_report(y_true_train, y_pred_train, labels=labels_idxs,
                            target_names=target_names, output_dict=False))
#%% Make prediction on val set
with torch.no_grad():
    y_true_test, y_pred_test = make_pred_on_dataloader(net, val_dataloader)
#%% Compute metrics on val set
cf_test = confusion_matrix(y_true_test, y_pred_test, labels=labels_idxs)
cr_test = classification_report(y_true_test, y_pred_test, labels=labels_idxs,
                                target_names=target_names, output_dict=True)
print(classification_report(y_true_test, y_pred_test, labels=labels_idxs,
                            target_names=target_names, output_dict=False))