logo
Browse Source

change the input format.

main
wxywb 4 years ago
parent
commit
19f1cba4f0
  1. 8
      README.md
  2. 8
      efficientnet_image_embedding.py

8
README.md

@ -25,14 +25,14 @@ __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch',
- supported types: `str`, default is None, using pretrained weights
```python
__call__(self, img_path: str)
__call__(self, image: 'towhee.types.Image')
```
**Args:**
- img_path:
- the input image path
- supported types: `str`
- image:
- the input image
- supported types: `towhee.types.Image`
**Returns:**

8
efficientnet_image_embedding.py

@ -17,15 +17,15 @@ from PIL import Image
import torch
from torchvision import transforms
import sys
import towhee
from pathlib import Path
import numpy
from towhee.operator import Operator
from towhee.utils.pil_utils import to_pil
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import os
import warnings
warnings.filterwarnings("ignore")
class EfficientnetImageEmbedding(Operator):
"""
@ -51,8 +51,8 @@ class EfficientnetImageEmbedding(Operator):
config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
img = self.tfms(Image.open(img_path)).unsqueeze(0)
img = self.tfms(to_pil(image)).unsqueeze(0)
features = self.model(img)
return Outputs(features.flatten().detach().numpy())

Loading…
Cancel
Save