diff --git a/yolov5.py b/yolov5.py index 4c8d1cd..a1af818 100644 --- a/yolov5.py +++ b/yolov5.py @@ -14,9 +14,9 @@ logging.getLogger("yolov5").setLevel(logging.WARNING) class Yolov5(NNOperator): def __init__(self): super().__init__() - model_path = str(Path(__file__).parent / 'models/yolov5s') + model_path = str(Path(__file__).parent / 'models/yolov5s.pt') self._model = torch.hub.load('ultralytics/yolov5', 'custom', model_path) - # self._model = torch.hub.load("ultralytics/yolov5", model_name, pretrained=True, verbose=False) + # self._model = torch.hub.load("ultralytics/yolov5", 'yolov5s', pretrained=True, verbose=False) def __call__(self, img: numpy.ndarray): # Get object detection results with YOLOv5 model