yolo
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
24 lines
921 B
24 lines
921 B
import torch
|
|
import numpy
|
|
from pathlib import Path
|
|
|
|
from towhee import register
|
|
from towhee.operator import NNOperator
|
|
|
|
@register(output_schema=['boxes', 'classes', 'scores'])
class Yolov5(NNOperator):
    """Object-detection operator that runs a YOLOv5 model loaded via ``torch.hub``.

    Calling an instance on an image returns one ``(box, class_name, score)``
    tuple per detection reported by the model for that image.
    """

    def __init__(self):
        """Load the YOLOv5 checkpoint stored next to this file."""
        super().__init__()
        # The weights are read from ``models/yolov5s.pt`` relative to this
        # source file (the 'custom' hub entry point), instead of asking the
        # hub for its own pretrained 'yolov5s' weights.
        model_path = str(Path(__file__).parent / 'models/yolov5s.pt')
        self._model = torch.hub.load('ultralytics/yolov5', 'custom', model_path)

    def __call__(self, img: numpy.ndarray) -> list:
        """Detect objects in ``img``.

        Args:
            img: Image passed unchanged to the YOLOv5 model.
                NOTE(review): channel order / dtype expectations are set by
                the model wrapper and are not visible here -- confirm with
                the caller.

        Returns:
            A list of ``(box, class_name, score)`` tuples, one per detection.
            ``box`` is a list of four ints taken from the first four columns
            of the detection row (presumably ``x1, y1, x2, y2`` per the
            ``xyxy`` naming -- confirm against the YOLOv5 docs),
            ``class_name`` comes from the ``name`` column and ``score`` from
            the ``confidence`` column of the model's pandas output.
        """
        # Get object detection results with YOLOv5 model
        results = self._model(img)

        # Index 0: only the detections for the first (single) image are read.
        detections = results.xyxy[0]
        # int() truncates each coordinate toward zero, giving plain Python ints.
        boxes = [[int(coord) for coord in det[0:4]] for det in detections]

        # Build the pandas view once and read both columns from it rather
        # than converting the results twice.
        frame = results.pandas().xyxy[0]
        classes = list(frame['name'])
        scores = list(frame['confidence'])

        return list(zip(boxes, classes, scores))
|