logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

28 lines
954 B

import torch
import numpy
import logging
from towhee import register
from towhee.operator import NNOperator
# Attach a default stderr handler to the root logger (basicConfig only does
# this when the root logger has no handlers yet), then cap the root, YOLOv5
# and Towhee-engine loggers at WARNING so routine INFO/DEBUG chatter is hidden.
logging.basicConfig()
for _quiet_logger in (None, "yolov5", "towhee.engine"):
    # logging.getLogger(None) is the root logger.
    logging.getLogger(_quiet_logger).setLevel(logging.WARNING)
# Drop the loop variable so it does not linger as a module attribute.
del _quiet_logger
@register(output_schema=['boxes', 'classes', 'scores'])
class Yolov5(NNOperator):
    """Towhee operator that runs a YOLOv5 object detector loaded via ``torch.hub``.

    Args:
        model_name: Model entry point passed to
            ``torch.hub.load("ultralytics/yolov5", model_name, ...)``,
            e.g. ``'yolov5s'``.
    """

    def __init__(self, model_name: str = 'yolov5s'):
        super().__init__()
        # torch.hub may fetch the repo and pretrained weights on first use;
        # verbose=False keeps hub's console output quiet.
        self._model = torch.hub.load(
            "ultralytics/yolov5", model_name, pretrained=True, verbose=False
        )

    def __call__(self, img: numpy.ndarray):
        """Run detection on one image.

        Args:
            img: Image handed directly to the YOLOv5 model.
                NOTE(review): expected channel order (RGB vs BGR) is not
                established here — confirm with callers.

        Returns:
            ``(boxes, classes, scores)``: parallel lists describing the
            detections for the first image in the results. ``boxes`` holds
            the first four ``xyxy`` values of each detection truncated to
            ``int``; ``classes`` holds the ``name`` column and ``scores`` the
            ``confidence`` column of the pandas view of the same results.
        """
        results = self._model(img)
        # Build the pandas view once and reuse it: each pandas() call copies
        # the results and regenerates its DataFrames.
        detections = results.pandas().xyxy[0]
        # Keep only the four box coordinates of each row; int() truncates.
        boxes = [[int(coord) for coord in det[0:4]] for det in results.xyxy[0]]
        classes = list(detections.name)
        scores = list(detections.confidence)
        return boxes, classes, scores