-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
24 lines (22 loc) · 782 Bytes
/
model.py
File metadata and controls
24 lines (22 loc) · 782 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
import numpy as np
from torchvision import models
class ResNetFashion(nn.Module):
    """Image classifier: a torchvision backbone followed by a linear head.

    Args:
        base: ``"inception"`` selects torchvision's Inception v3; any other
            value selects ResNet-50.
        num_classes: Number of logits produced by the final linear layer.
        use_pretrained: Forwarded to torchvision as ``pretrained`` to load
            the library's pretrained backbone weights.

    No backbone layers are frozen; the whole network is trainable.
    """

    def __init__(self, base="resnet50", num_classes=20, use_pretrained=True):
        super().__init__()
        # NOTE(review): newer torchvision deprecates ``pretrained=`` in favour
        # of ``weights=``; kept here for compatibility with older releases.
        if base == "inception":
            # Inception v3 cannot be rebuilt as nn.Sequential of its children:
            # its forward() applies an input transform and routes the AuxLogits
            # branch separately, so chaining the children would feed AuxLogits'
            # output into Mixed_7a and crash. Keep the model intact, drop the
            # auxiliary head, and replace the classifier with Identity so the
            # backbone returns its flattened pooled features.
            # NOTE(review): torchvision documents Inception v3 as expecting
            # 299x299 inputs -- confirm the data pipeline resizes accordingly.
            backbone = models.inception_v3(
                pretrained=use_pretrained, aux_logits=False
            )
            in_features = backbone.fc.in_features
            backbone.fc = nn.Identity()
            self.base = backbone
        else:  # use resnet50 in all other cases
            backbone = models.resnet50(pretrained=use_pretrained)
            in_features = backbone.fc.in_features
            # Drop only the final fc layer. This keeps the same nn.Sequential
            # layout as before (keys "base.0", ..., "base.flatten"), so
            # previously saved state_dicts still load.
            self.base = nn.Sequential(*list(backbone.children())[:-1])
            # Flatten the pooled (N, C, 1, 1) map to (N, C) for the linear head.
            self.base.add_module('flatten', nn.Flatten())
        # Size the head from the backbone instead of hard-coding 2048.
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``.

        ``x`` is presumably an image batch of shape ``(batch, 3, H, W)`` --
        TODO confirm against the data loader.
        """
        features = self.base(x)  # (batch, in_features) after flattening
        return self.fc(features)