-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscript.py
More file actions
32 lines (26 loc) · 692 Bytes
/
script.py
File metadata and controls
32 lines (26 loc) · 692 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from src.net import Net
import json
import argparse
from src.load_dataset import load_dataset
# Inference script: loads the trained weights from PATH, classifies every
# batch yielded by load_dataset(<path argument>), and writes a
# {name: label} mapping to process_results.json.

# Location of the trained weights loaded into the network below.
PATH = './model.pth'

# Maps a predicted class index (argmax over the network outputs) to its label.
class_map = {
    0: 'female',
    1: 'male',
}

# The single positional argument is passed straight to load_dataset().
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str)
args = parser.parse_args()

net = Net()
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; load_state_dict then copies the tensors into the
# model's own parameters.
net.load_state_dict(torch.load(PATH, map_location='cpu'))
# Put the network in evaluation mode so layers such as dropout or batch
# normalisation (if Net has any) use their inference behaviour.
net.eval()

loader, names = load_dataset(args.path)

predicted_general = []
# No gradients are needed for prediction; disabling autograd avoids building
# the graph and keeps memory use down.
with torch.no_grad():
    for images in loader:
        outputs = net(images)
        # torch.max over dim 1 returns (values, indices); the indices are the
        # predicted class ids for each item in the batch.
        _, predicted = torch.max(outputs, 1)
        predicted_general.extend(predicted.tolist())

# Names and predictions are paired by position, so every name needs a
# prediction; fail with a clear message instead of a bare IndexError.
if len(predicted_general) < len(names):
    raise ValueError(
        f'got {len(predicted_general)} predictions for {len(names)} names'
    )

data = {
    name: class_map[label]
    for name, label in zip(names, predicted_general)
}

with open('process_results.json', 'w') as outfile:
    json.dump(data, outfile)