logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

29 lines
1.0 KiB

import torch
import numpy
import logging
from pathlib import Path
from towhee import register
from towhee.operator import NNOperator
# Import-time logging setup: ask basicConfig for WARNING, then explicitly pin
# both the root logger and the "yolov5" logger to WARNING. The explicit
# setLevel on the root matters because basicConfig does nothing when the root
# logger already has handlers.
# NOTE(review): this reconfigures the root logger of the whole host process,
# not just this operator.
logging.basicConfig(level=logging.WARNING)
for _logger_name in (None, "yolov5"):  # None -> root logger
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
del _logger_name
@register(output_schema=['boxes', 'classes', 'scores'])
class Yolov5(NNOperator):
    """Towhee operator that runs a YOLOv5 detector on one image.

    The model is obtained through the ``ultralytics/yolov5`` torch.hub entry
    point ``'custom'``, using the checkpoint stored at ``models/yolov5s``
    next to this file.
    """

    def __init__(self):
        super().__init__()
        model_path = str(Path(__file__).parent / 'models/yolov5s')
        # NOTE(review): torch.hub.load resolves the 'ultralytics/yolov5' repo
        # (may require network access if it is not already in the hub cache);
        # only the weights come from the local model_path.
        self._model = torch.hub.load('ultralytics/yolov5', 'custom', model_path)

    def __call__(self, img: numpy.ndarray):
        """Run object detection on ``img``.

        Args:
            img: Image array passed unchanged to the hub model.
                NOTE(review): presumably HxWxC; confirm the channel order
                (RGB vs BGR) expected by the model with the caller.

        Returns:
            A ``(boxes, classes, scores)`` tuple for the first image of the
            model output (``[0]``), with entries in matching order:

            - boxes: first four columns of each ``results.xyxy[0]`` row,
              converted with ``int()`` (truncates toward zero); presumably
              ``[x1, y1, x2, y2]`` pixel coordinates.
            - classes: values of the ``name`` column of the pandas view.
            - scores: values of the ``confidence`` column of the pandas view.
        """
        # Get object detection results with the YOLOv5 model.
        results = self._model(img)
        # Build the pandas view once and reuse it; the original code called
        # results.pandas() separately for classes and for scores.
        detections = results.pandas().xyxy[0]
        boxes = [list(map(int, row[0:4])) for row in results.xyxy[0]]
        classes = list(detections.name)
        scores = list(detections.confidence)
        return boxes, classes, scores