|
|
@ -16,13 +16,14 @@ import logging |
|
|
|
import numpy |
|
|
|
|
|
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee.types import Image as towheeImage |
|
|
|
from towhee.types.arg import arg, to_image_color |
|
|
|
from towhee import register |
|
|
|
|
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
from PIL import Image as PILImage |
|
|
|
import cv2 |
|
|
|
|
|
|
|
from timm.data.transforms_factory import create_transform |
|
|
|
from timm.data import resolve_data_config |
|
|
@ -43,9 +44,11 @@ class TimmImage(NNOperator): |
|
|
|
Which model to use for the embeddings. |
|
|
|
num_classes (`int = 1000`): |
|
|
|
Number of classes for classification. |
|
|
|
skip_preprocess (`bool = False`): |
|
|
|
Whether skip image transforms. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model_name: str, num_classes: int = 1000) -> None: |
|
|
|
def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None: |
|
|
|
super().__init__() |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
@ -53,19 +56,12 @@ class TimmImage(NNOperator): |
|
|
|
self.model.eval() |
|
|
|
config = resolve_data_config({}, model=self.model) |
|
|
|
self.tfms = create_transform(**config) |
|
|
|
self.skip_tfms = skip_preprocess |
|
|
|
|
|
|
|
def __call__(self, img: numpy.ndarray) -> numpy.ndarray: |
|
|
|
if hasattr(img, 'mode'): |
|
|
|
if img.mode not in ['RGB', 'BGR']: |
|
|
|
log.error(f'Invalid image mode: expect "RGB" or "BGR" but receive "{img.mode}".') |
|
|
|
raise AssertionError(f'Invalid image mode "{img.mode}".') |
|
|
|
elif img.mode == 'BGR': |
|
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
|
log.warning('Converting image mode from "BGR" to "RGB" ...') |
|
|
|
else: |
|
|
|
log.warning(f'Image mode is not specified. Using "RGB" now.') |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def __call__(self, img: towheeImage) -> numpy.ndarray: |
|
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
|
if not self.skip_tfms: |
|
|
|
img = self.tfms(img).unsqueeze(0) |
|
|
|
img = img.to(self.device) |
|
|
|
features = self.model.forward_features(img) |
|
|
@ -79,13 +75,13 @@ class TimmImage(NNOperator): |
|
|
|
|
|
|
|
|
|
|
|
# if __name__ == '__main__': |
|
|
|
# from towhee._types import Image |
|
|
|
# from towhee import ops |
|
|
|
# |
|
|
|
# path = '/image/path/or/link' |
|
|
|
# |
|
|
|
# path = '/path/to/image' |
|
|
|
# img = cv2.imread(path) |
|
|
|
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
|
# img = Image(img) |
|
|
|
# decoder = ops.image_decode.cv2() |
|
|
|
# img = decoder(path) |
|
|
|
# |
|
|
|
# op = TimmImage('resnet50') |
|
|
|
# out = op(img) |
|
|
|
# print(out) |
|
|
|