-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_state_dict.py
More file actions
31 lines (27 loc) · 1.11 KB
/
test_state_dict.py
File metadata and controls
31 lines (27 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from models.resnet import QuantizedResNet18
from test_models import test
import torch
## Creating quant4 and quant2
# Build a pretrained QuantizedResNet18 at each bit width, quantize it, and
# save its quantized state dict to quant{bits}.pth. One loop replaces the
# previously copy-pasted per-bit-width sections so both paths stay in sync.
#
# NOTE(review): this script does its work at module top level. If pytest
# collects this file (test_*.py), everything runs during collection, and the
# imported `test` function may itself be collected as a test. Consider an
# `if __name__ == "__main__":` guard if the file is not meant for pytest.

# Set to True to report the accuracy of each freshly quantized model before
# it is saved, as a reference for the reloaded accuracy printed further down.
EVAL_BEFORE_SAVE = False

for bits in (4, 2):
    # NOTE(review): the second positional argument (32) is passed through
    # unchanged; its meaning is defined in models.resnet — confirm.
    model = QuantizedResNet18(bits, 32, pretrained=True)
    model.quantize()
    torch.save(model.state_dict_quant(bits=bits), f"quant{bits}.pth")
    if EVAL_BEFORE_SAVE:
        # test() yields (top-1 correct, top-5 correct, total), per the unpacking.
        correct1, correct5, total = test(model, 128)
        print(f"Results from unloaded quant{bits} model: Top1 {correct1/total}")
## Testing quant4 and quant2
# Rebuild each model with pretrained=False, load the quantized state dict
# saved as quant{bits}.pth, and print its top-1 accuracy. pretrained=False
# presumably skips loading pretrained weights, so the accuracy reflects only
# what was restored from the file — confirm in models.resnet. One loop
# replaces the previously copy-pasted per-bit-width sections.
for bits in (4, 2):
    model = QuantizedResNet18(bits, 32, pretrained=False)
    # NOTE(review): torch.load is called without map_location. If the saved
    # dict holds CUDA tensors, this fails on a CPU-only machine — confirm
    # where the tensors live. On torch>=2.6 weights_only defaults to True,
    # which rejects arbitrary pickled objects if state_dict_quant() stores
    # anything beyond tensors and primitive containers.
    model.load_state_dict_quant(torch.load(f"quant{bits}.pth"), bits=bits)
    # test() yields (top-1 correct, top-5 correct, total); 128 is presumably
    # the batch size — confirm in test_models. correct5 is currently unused.
    correct1, correct5, total = test(model, 128)
    print(f"Results from loaded quant{bits} model: Top1 {correct1/total}")