diff --git a/README.md b/README.md index dda0f47..1f4cdd9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,59 @@ -# efficientnet-embedding +# EfficientNet Embedding Operator + +Authors: kyle he + +## Overview + +EfficientNets are a family of image classification models that achieve state-of-the-art accuracy while being an order of magnitude smaller and faster than previous models [1]. The models are trained on the [ImageNet dataset](https://image-net.org/download.php). + + +## Interface + +```python +__init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) +``` + +**Args:** + +- model_name: + - the model name for embedding + - supported types: `str`, for example 'efficientnet-b7' +- framework: + - the framework of the model + - supported types: `str`, default is 'pytorch' +- weights_path: + - the weights path + - supported types: `str`, default is None, using pretrained weights + +```python +__call__(self, img_path: str) +``` + +**Args:** + +- img_path: + - the input image path + - supported types: `str` + +**Returns:** + +The Operator returns a tuple `Tuple[('feature_vector', numpy.ndarray)]` containing the following fields: + +- feature_vector: + - the embedding of the image + - data type: `numpy.ndarray` + +## Requirements + +The required Python packages are listed in [requirements.txt](./requirements.txt). + +## How it works + +The `towhee/efficientnet-embedding` Operator implements image embedding and can be added to a pipeline. For example, it is the key Operator named embedding_model within the [efficientnet-embedding](https://hub.towhee.io/towhee/efficientnet-embedding) pipeline. + + + +## Reference + +[1] https://github.com/lukemelas/EfficientNet-PyTorch#example-feature-extraction -This is another test repo \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..efca5a9 --- /dev/null +++ b/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2021 Zilliz. All rights reserved. 
import numpy  # needed by this block for the corrected output annotation

# Output record type, built once so that every call and the return annotation
# share the same class instead of re-creating it on each invocation.
_Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)])


class EfficientnetEmbeddingOperator(Operator):
    """
    Embedding extractor using efficientnet.

    Args:
        model_name (`str`):
            Model name, for example 'efficientnet-b7'.
        framework (`str`):
            Backend framework. Only 'pytorch' is supported.
        weights_path (`str`):
            Path to local weights; forwarded unchanged to the model wrapper
            (default `None`).

    Raises:
        ValueError: If `framework` is not 'pytorch'.
    """

    def __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) -> None:
        super().__init__()
        # Fail early with a clear message; previously an unsupported framework
        # skipped the import below and surfaced as an unrelated name error.
        if framework != 'pytorch':
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is available."
            )

        # The backend package (`pytorch/`) lives next to this file. Make it
        # importable, without adding a duplicate entry per instantiation.
        operator_dir = str(Path(__file__).parent)
        if operator_dir not in sys.path:
            sys.path.append(operator_dir)
        from pytorch.model import Model

        self.model = Model(model_name, weights_path)
        # NOTE(review): Resize(224) scales only the shorter side, so the length
        # of the flattened feature vector presumably varies with the input
        # aspect ratio -- confirm whether a fixed-size crop is wanted.
        self.tfms = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, img_path: str) -> _Outputs:
        """
        Compute the embedding of the image stored at `img_path`.

        Args:
            img_path (`str`):
                Path of the input image.

        Returns:
            A named tuple with a single field `embedding`: the flattened
            output of the model's `extract_features` as a `numpy.ndarray`.
        """
        # Close the file handle deterministically and force three channels:
        # the Normalize step is configured with 3-channel statistics, so
        # grayscale, palette or RGBA inputs would otherwise fail.
        with Image.open(img_path) as img_file:
            rgb_img = img_file.convert('RGB')
        batch = self.tfms(rgb_img).unsqueeze(0)

        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            features = self.model._model.extract_features(batch)
        return _Outputs(features.flatten().detach().numpy())
class Model():
    """
    PyTorch model class.

    Wraps an `EfficientNet` created with `EfficientNet.from_pretrained` and
    switched to evaluation mode.

    Args:
        model_name (`str`):
            Name passed to `EfficientNet.from_pretrained`.
        weights_path (`str`):
            Weights path passed to `EfficientNet.from_pretrained`; the
            operator passes `None` by default.
    """

    def __init__(self, model_name: str, weights_path: str):
        super().__init__()
        self._model = EfficientNet.from_pretrained(model_name=model_name, weights_path=weights_path)
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        """
        Run the wrapped model on `img_tensor` and return the result as a
        `numpy.ndarray`.
        """
        # Inference only: skip autograd bookkeeping during the forward pass.
        with torch.no_grad():
            output = self._model(img_tensor)
        # detach() still matters when the model hands back a tensor that
        # requires grad; cpu() is a no-op on CPU tensors and makes numpy()
        # valid for tensors living on another device.
        return output.detach().cpu().numpy()

    def train(self):
        """
        For training model. Placeholder: currently does nothing.
        """
        pass