logo
Browse Source

Add more BLIP models to the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
eefddb5f13
  1. 11
      README.md
  2. 2
      __init__.py
  3. 15
      blip.py

11
README.md

@ -29,13 +29,13 @@ from towhee.dc2 import pipe, ops, DataCollection
img_pipe = (
pipe.input('url')
.map('url', 'img', ops.image_decode.cv2_rgb())
.map('img', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image'))
.map('img', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image'))
.output('img', 'vec')
)
text_pipe = (
pipe.input('text')
.map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='text'))
.map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='text'))
.output('text', 'vec')
)
@ -62,7 +62,10 @@ Create the operator via the following factory method
***model_name:*** *str*
The model name of BLIP. Supported model names:
- blip_itm_base
- blip_itm_base_coco
- blip_itm_large_coco
- blip_itm_base_flickr
- blip_itm_large_flickr
***modality:*** *str*
@ -95,7 +98,7 @@ Save model to local with specified format.
```python
from towhee import ops
op = ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image').get_op()
op = ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op()
op.save_model('onnx', 'test.onnx')
```
<br />

2
__init__.py

@ -15,5 +15,5 @@
from .blip import Blip
def blip(model_name: str, modality: str, device:str = None, checkpoint_path:str = None):
def blip(model_name: str, modality: str, device:str = 'cpu', checkpoint_path:str = None):
return Blip(model_name, modality, device, checkpoint_path)

15
blip.py

@ -26,7 +26,7 @@ from transformers import logging as t_logging
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee.dc2 import accelerate
#from towhee.dc2 import accelerate
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
@ -76,7 +76,7 @@ class BLIPModelText(nn.Module):
return text_features
class Model:
def __init__(self, model_name, modality, device, checkpoint_path):
def __init__(self, model_name, modality, checkpoint_path, device):
self.model = create_model(model_name, modality, checkpoint_path, device)
self.device = device
@ -143,9 +143,14 @@ class Blip(NNOperator):
def _configs(self):
config = {}
config['blip_itm_base'] = {}
config['blip_itm_base']['name'] = 'Salesforce/blip-itm-base-coco'
config['blip_itm_base']['image_size'] = 224
config['blip_itm_base_coco'] = {}
config['blip_itm_base_coco']['name'] = 'Salesforce/blip-itm-base-coco'
config['blip_itm_base_flickr'] = {}
config['blip_itm_base_flickr']['name'] = 'Salesforce/blip-itm-base-flickr'
config['blip_itm_large_coco'] = {}
config['blip_itm_large_coco']['name'] = 'Salesforce/blip-itm-large-coco'
config['blip_itm_large_flickr'] = {}
config['blip_itm_large_flickr']['name'] = 'Salesforce/blip-itm-large-flickr'
return config
@property

Loading…
Cancel
Save