Browse Source
update the BLIP operator.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
3 changed files with
40 additions and
7 deletions
-
README.md
-
__init__.py
-
blip.py
|
|
@ -77,6 +77,38 @@ Create the operator via the following factory method |
|
|
|
|
|
|
|
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or a string as input and generates an embedding as an ndarray. |
|
|
|
|
|
|
|
***save_model(format='pytorch', path='default')*** |
|
|
|
|
|
|
|
Save the model to a local path in the specified format. |
|
|
|
|
|
|
|
**Parameters:** |
|
|
|
|
|
|
|
***format***: *str* |
|
|
|
|
|
|
|
The format of the saved model; defaults to 'pytorch'. |
|
|
|
|
|
|
|
***path***: *str* |
|
|
|
|
|
|
|
The path where the model is saved. By default, the model is saved to the operator directory. |
|
|
|
|
|
|
|
|
|
|
|
```python |
|
|
|
from towhee import ops |
|
|
|
|
|
|
|
op = ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image').get_op() |
|
|
|
op.save_model('onnx', 'test.onnx') |
|
|
|
``` |
|
|
|
<br /> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
**Parameters:** |
|
|
|
|
|
|
|
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str* |
|
|
|
|
|
|
|
The data (an image or text, depending on the specified modality) from which to generate the embedding. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
**Parameters:** |
|
|
|
|
|
|
|
|
|
@ -15,5 +15,5 @@ |
|
|
|
from .blip import Blip |
|
|
|
|
|
|
|
|
|
|
|
def blip(model_name: str, modality: str, device: str = None, checkpoint_path: str = None):
    """Factory entry point that constructs the ``Blip`` operator.

    All four arguments are forwarded positionally, in this order, to the
    ``Blip`` constructor; ``device`` and ``checkpoint_path`` default to
    ``None`` so existing two-argument calls keep working.
    """
    operator = Blip(model_name, modality, device, checkpoint_path)
    return operator
|
|
|
|
|
@ -76,7 +76,7 @@ class BLIPModelText(nn.Module): |
|
|
|
return text_features |
|
|
|
|
|
|
|
class Model: |
|
|
|
def __init__(self, model_name, modality, checkpoint_path, device): |
|
|
|
def __init__(self, model_name, modality, device, checkpoint_path): |
|
|
|
self.model = create_model(model_name, modality, checkpoint_path, device) |
|
|
|
self.device = device |
|
|
|
|
|
|
@ -165,10 +165,11 @@ class Blip(NNOperator): |
|
|
|
|
|
|
|
@staticmethod
def supported_model_names(format: str = None):
    """Return the model names available for the requested format.

    Args:
        format: ``None`` to request the full list, or one of 'pytorch',
            'torchscript' or 'onnx'.

    Returns:
        A list of model names. For an unrecognised format an error is
        logged and an empty list is returned.
    """
    full_list = ['blip_itm_base']
    # Every listed model is currently available in every recognised format,
    # so all valid requests resolve to the same list.
    if format is None or format in ('pytorch', 'torchscript', 'onnx'):
        return full_list
    # The invalid-format path used to fall through to `return model_list`
    # with the name never bound, raising UnboundLocalError right after the
    # log call. Return an explicit empty list instead.
    log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript", "onnx".')
    return []
|
|
|