From 173282622063184b68aaa7ec7918746a506e3c6c Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Wed, 28 Dec 2022 19:37:21 +0800 Subject: [PATCH] Add models Signed-off-by: shiyu22 --- models/yolov5s.pt | 3 +++ yolov5.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 models/yolov5s.pt diff --git a/models/yolov5s.pt b/models/yolov5s.pt new file mode 100644 index 0000000..cd69a5e --- /dev/null +++ b/models/yolov5s.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b3b748c1e592ddd8868022e8732fde20025197328490623cc16c6f24d0782ee +size 14808437 diff --git a/yolov5.py b/yolov5.py index 81daa88..4c8d1cd 100644 --- a/yolov5.py +++ b/yolov5.py @@ -1,6 +1,7 @@ import torch import numpy import logging +from pathlib import Path from towhee import register from towhee.operator import NNOperator @@ -11,9 +12,11 @@ logging.getLogger("yolov5").setLevel(logging.WARNING) @register(output_schema=['boxes', 'classes', 'scores']) class Yolov5(NNOperator): - def __init__(self, model_name: str ='yolov5s'): + def __init__(self): super().__init__() - self._model = torch.hub.load("ultralytics/yolov5", model_name, pretrained=True, verbose=False) + model_path = str(Path(__file__).parent / 'models/yolov5s') + self._model = torch.hub.load('ultralytics/yolov5', 'custom', model_path) + # self._model = torch.hub.load("ultralytics/yolov5", model_name, pretrained=True, verbose=False) def __call__(self, img: numpy.ndarray): # Get object detection results with YOLOv5 model