logo
Browse Source

Change the input format of `__call__` from an image path (`str`) to a `towhee.types.Image`.

main
wxywb 4 years ago
parent
commit
19f1cba4f0
  1. 8
      README.md
  2. 8
      efficientnet_image_embedding.py

8
README.md

@@ -25,14 +25,14 @@ __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch',
- supported types: `str`, default is None, using pretrained weights - supported types: `str`, default is None, using pretrained weights
```python ```python
__call__(self, img_path: str)
__call__(self, image: 'towhee.types.Image')
``` ```
**Args:** **Args:**
- img_path:
- the input image path
- supported types: `str`
- image:
- the input image
- supported types: `towhee.types.Image`
**Returns:** **Returns:**

8
efficientnet_image_embedding.py

@@ -17,15 +17,15 @@ from PIL import Image
import torch import torch
from torchvision import transforms from torchvision import transforms
import sys import sys
import towhee
from pathlib import Path from pathlib import Path
import numpy import numpy
from towhee.operator import Operator from towhee.operator import Operator
from towhee.utils.pil_utils import to_pil
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform from timm.data.transforms_factory import create_transform
import os import os
import warnings
warnings.filterwarnings("ignore")
class EfficientnetImageEmbedding(Operator): class EfficientnetImageEmbedding(Operator):
""" """
@@ -51,8 +51,8 @@ class EfficientnetImageEmbedding(Operator):
config = resolve_data_config({}, model=self.model._model) config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config) self.tfms = create_transform(**config)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
img = self.tfms(Image.open(img_path)).unsqueeze(0)
img = self.tfms(to_pil(image)).unsqueeze(0)
features = self.model(img) features = self.model(img)
return Outputs(features.flatten().detach().numpy()) return Outputs(features.flatten().detach().numpy())

Loading…
Cancel
Save