logo
Browse Source

update the BLIP operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
0286f17907
  1. 32
      README.md
  2. 4
      __init__.py
  3. 11
      blip.py

32
README.md

@ -77,6 +77,38 @@ Create the operator via the following factory method
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generate an embedding in ndarray.
***save_model(format='pytorch', path='default')***
Save model to local with specified format.
**Parameters:**
***format***: *str*
​ The format of saved model, defaults to 'pytorch'.
***path***: *str*
​ The path where model is saved to. By default, it will save model to the operator directory.
```python
from towhee import ops
op = ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image').get_op()
op.save_model('onnx', 'test.onnx')
```
<br />
**Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
​ The data (image or text based on specified modality) to generate embedding.
**Parameters:**

4
__init__.py

@ -15,5 +15,5 @@
from .blip import Blip
def blip(model_name: str, modality: str):
return Blip(model_name, modality)
def blip(model_name: str, modality: str, device:str = None, checkpoint_path:str = None):
return Blip(model_name, modality, device, checkpoint_path)

11
blip.py

@ -76,7 +76,7 @@ class BLIPModelText(nn.Module):
return text_features
class Model:
def __init__(self, model_name, modality, checkpoint_path, device):
def __init__(self, model_name, modality, device, checkpoint_path):
self.model = create_model(model_name, modality, checkpoint_path, device)
self.device = device
@ -165,10 +165,11 @@ class Blip(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = [
'blip_itm_base',
]
full_list = ['blip_itm_base']
if format == None:
model_list = full_list
elif format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = full_list
else:
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list

Loading…
Cancel
Save