towhee
/
efficientnet-image-embedding
copied
8 changed files with 192 additions and 2 deletions
@ -1,3 +1,59 @@ |
|||
# efficientnet-embedding |
|||
# Efficientnet Embedding Operator |
|||
|
|||
Authors: kyle he |
|||
|
|||
## Overview |
|||
|
|||
EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models[1], which is trained on [imagenet dataset](https://image-net.org/download.php). |
|||
|
|||
|
|||
## Interface |
|||
|
|||
```python |
|||
__init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) |
|||
``` |
|||
|
|||
**Args:** |
|||
|
|||
- model_name: |
|||
- the model name for embedding |
|||
- supported types: `str`, for example 'efficientnet-b7' |
|||
- framework: |
|||
- the framework of the model |
|||
- supported types: `str`, default is 'pytorch' |
|||
- weights_path: |
|||
- the weights path |
|||
- supported types: `str`, default is None, using pretrained weights |
|||
|
|||
```python |
|||
__call__(self, img_path: str) |
|||
``` |
|||
|
|||
**Args:** |
|||
|
|||
- img_path: |
|||
- the input image path |
|||
- supported types: `str` |
|||
|
|||
**Returns:** |
|||
|
|||
The Operator returns a tuple `Tuple[('feature_vector', numpy.ndarray)]` containing following fields: |
|||
|
|||
- feature_vector: |
|||
- the embedding of the image |
|||
- data type: `numpy.ndarray` |
|||
|
|||
## Requirements |
|||
|
|||
You can get the required python package by [requirements.txt](./requirements.txt). |
|||
|
|||
## How it works |
|||
|
|||
The `towhee/efficientnet-embedding` Operator implements the function of image embedding, which can add to the pipeline. For example, it's the key Operator named embedding_model within [efficientnet-embedding](https://hub.towhee.io/towhee/efficientnet-embedding) pipeline. |
|||
|
|||
|
|||
|
|||
## Reference |
|||
|
|||
[1].https://github.com/lukemelas/EfficientNet-PyTorch#example-feature-extraction |
|||
|
|||
This is another test repo |
@ -0,0 +1,21 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import os |
|||
|
|||
# For requirements. |
|||
try: |
|||
import efficientnet_pytorch |
|||
except ModuleNotFoundError: |
|||
os.system('pip install efficientnet_pytorch') |
@ -0,0 +1,48 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from typing import NamedTuple |
|||
from PIL import Image |
|||
import torch |
|||
from torchvision import transforms |
|||
import sys |
|||
from pathlib import Path |
|||
|
|||
from towhee.operator import Operator |
|||
|
|||
|
|||
class EfficientnetEmbeddingOperator(Operator): |
|||
""" |
|||
Embedding extractor using efficientnet. |
|||
Args: |
|||
model_name (`string`): |
|||
Model name. |
|||
weights_path (`string`): |
|||
Path to local weights. |
|||
""" |
|||
|
|||
def __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) -> None: |
|||
super().__init__() |
|||
sys.path.append(str(Path(__file__).parent)) |
|||
if framework == 'pytorch': |
|||
from pytorch.model import Model |
|||
self.model = Model(model_name, weights_path) |
|||
self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), |
|||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) |
|||
|
|||
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): |
|||
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) |
|||
img = self.tfms(Image.open(img_path)).unsqueeze(0) |
|||
features = self.model._model.extract_features(img) |
|||
return Outputs(features.flatten().detach().numpy()) |
@ -0,0 +1,13 @@ |
|||
name: 'efficientnet-embedding' |
|||
labels: |
|||
recommended_framework: pytorch1.2.0 |
|||
class: image-embedding |
|||
others: efficientnet |
|||
operator: 'towhee/efficientnet-embedding' |
|||
init: |
|||
model_name: str |
|||
call: |
|||
input: |
|||
img_path: str |
|||
output: |
|||
feature_vector: numpy.ndarray |
@ -0,0 +1,13 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
@ -0,0 +1,39 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
from typing import NamedTuple |
|||
|
|||
import numpy |
|||
import torch |
|||
from efficientnet_pytorch import EfficientNet |
|||
|
|||
|
|||
class Model(): |
|||
""" |
|||
PyTorch model class |
|||
""" |
|||
def __init__(self, model_name: str, weights_path: str): |
|||
super().__init__() |
|||
self._model = EfficientNet.from_pretrained(model_name=model_name, weights_path=weights_path) |
|||
self._model.eval() |
|||
|
|||
def __call__(self, img_tensor: torch.Tensor): |
|||
return self._model(img_tensor).detach().numpy() |
|||
|
|||
def train(self): |
|||
""" |
|||
For training model |
|||
""" |
|||
pass |
After Width: | Height: | Size: 262 KiB |
Loading…
Reference in new issue