towhee
/
efficientnet-image-embedding
copied
8 changed files with 192 additions and 2 deletions
@ -1,3 +1,59 @@ |
|||||
# efficientnet-embedding |
|
||||
|
# Efficientnet Embedding Operator |
||||
|
|
||||
|
Authors: kyle he |
||||
|
|
||||
|
## Overview |
||||
|
|
||||
|
EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models[1], which is trained on [imagenet dataset](https://image-net.org/download.php). |
||||
|
|
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
```python |
||||
|
__init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) |
||||
|
``` |
||||
|
|
||||
|
**Args:** |
||||
|
|
||||
|
- model_name: |
||||
|
- the model name for embedding |
||||
|
- supported types: `str`, for example 'efficientnet-b7' |
||||
|
- framework: |
||||
|
- the framework of the model |
||||
|
- supported types: `str`, default is 'pytorch' |
||||
|
- weights_path: |
||||
|
- the weights path |
||||
|
- supported types: `str`, default is None, using pretrained weights |
||||
|
|
||||
|
```python |
||||
|
__call__(self, img_path: str) |
||||
|
``` |
||||
|
|
||||
|
**Args:** |
||||
|
|
||||
|
- img_path: |
||||
|
- the input image path |
||||
|
- supported types: `str` |
||||
|
|
||||
|
**Returns:** |
||||
|
|
||||
|
The Operator returns a tuple `Tuple[('feature_vector', numpy.ndarray)]` containing following fields: |
||||
|
|
||||
|
- feature_vector: |
||||
|
- the embedding of the image |
||||
|
- data type: `numpy.ndarray` |
||||
|
|
||||
|
## Requirements |
||||
|
|
||||
|
You can get the required python package by [requirements.txt](./requirements.txt). |
||||
|
|
||||
|
## How it works |
||||
|
|
||||
|
The `towhee/efficientnet-embedding` Operator implements the function of image embedding, which can add to the pipeline. For example, it's the key Operator named embedding_model within [efficientnet-embedding](https://hub.towhee.io/towhee/efficientnet-embedding) pipeline. |
||||
|
|
||||
|
|
||||
|
|
||||
|
## Reference |
||||
|
|
||||
|
[1].https://github.com/lukemelas/EfficientNet-PyTorch#example-feature-extraction |
||||
|
|
||||
This is another test repo |
|
@ -0,0 +1,21 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import os |
||||
|
|
||||
|
# For requirements. |
||||
|
try: |
||||
|
import efficientnet_pytorch |
||||
|
except ModuleNotFoundError: |
||||
|
os.system('pip install efficientnet_pytorch') |
@ -0,0 +1,48 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from typing import NamedTuple |
||||
|
from PIL import Image |
||||
|
import torch |
||||
|
from torchvision import transforms |
||||
|
import sys |
||||
|
from pathlib import Path |
||||
|
|
||||
|
from towhee.operator import Operator |
||||
|
|
||||
|
|
||||
|
class EfficientnetEmbeddingOperator(Operator): |
||||
|
""" |
||||
|
Embedding extractor using efficientnet. |
||||
|
Args: |
||||
|
model_name (`string`): |
||||
|
Model name. |
||||
|
weights_path (`string`): |
||||
|
Path to local weights. |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) -> None: |
||||
|
super().__init__() |
||||
|
sys.path.append(str(Path(__file__).parent)) |
||||
|
if framework == 'pytorch': |
||||
|
from pytorch.model import Model |
||||
|
self.model = Model(model_name, weights_path) |
||||
|
self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), |
||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) |
||||
|
|
||||
|
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): |
||||
|
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) |
||||
|
img = self.tfms(Image.open(img_path)).unsqueeze(0) |
||||
|
features = self.model._model.extract_features(img) |
||||
|
return Outputs(features.flatten().detach().numpy()) |
@ -0,0 +1,13 @@ |
|||||
|
name: 'efficientnet-embedding' |
||||
|
labels: |
||||
|
recommended_framework: pytorch1.2.0 |
||||
|
class: image-embedding |
||||
|
others: efficientnet |
||||
|
operator: 'towhee/efficientnet-embedding' |
||||
|
init: |
||||
|
model_name: str |
||||
|
call: |
||||
|
input: |
||||
|
img_path: str |
||||
|
output: |
||||
|
feature_vector: numpy.ndarray |
@ -0,0 +1,13 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
@ -0,0 +1,39 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
|
||||
|
from typing import NamedTuple |
||||
|
|
||||
|
import numpy |
||||
|
import torch |
||||
|
from efficientnet_pytorch import EfficientNet |
||||
|
|
||||
|
|
||||
|
class Model(): |
||||
|
""" |
||||
|
PyTorch model class |
||||
|
""" |
||||
|
def __init__(self, model_name: str, weights_path: str): |
||||
|
super().__init__() |
||||
|
self._model = EfficientNet.from_pretrained(model_name=model_name, weights_path=weights_path) |
||||
|
self._model.eval() |
||||
|
|
||||
|
def __call__(self, img_tensor: torch.Tensor): |
||||
|
return self._model(img_tensor).detach().numpy() |
||||
|
|
||||
|
def train(self): |
||||
|
""" |
||||
|
For training model |
||||
|
""" |
||||
|
pass |
After Width: | Height: | Size: 262 KiB |
Loading…
Reference in new issue