diff --git a/pytorch/model.py b/pytorch/model.py index 125ec3a..3a32524 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -28,9 +28,15 @@ class Model(): super().__init__() model_func = getattr(torchvision.models, model_name) self._model = model_func(pretrained=True) - state_dict = torch.hub.load_state_dict_from_url( - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') - self._model.load_state_dict(state_dict) + state_dict = None + if model_name == 'resnet101': + state_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') + if model_name == 'resnet50': + state_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth') + if state_dict: + self._model.load_state_dict(state_dict) self._model.fc = torch.nn.Identity() self._model.eval()