diff --git a/README.md b/README.md index 2211a84..69e3735 100644 --- a/README.md +++ b/README.md @@ -29,13 +29,13 @@ from towhee.dc2 import pipe, ops, DataCollection img_pipe = ( pipe.input('url') .map('url', 'img', ops.image_decode.cv2_rgb()) - .map('img', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image')) + .map('img', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image')) .output('img', 'vec') ) text_pipe = ( pipe.input('text') - .map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='text')) + .map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='text')) .output('text', 'vec') ) @@ -62,7 +62,10 @@ Create the operator via the following factory method ​ ***model_name:*** *str* ​ The model name of BLIP. Supported model names: -- blip_itm_base +- blip_itm_base_coco +- blip_itm_large_coco +- blip_itm_base_flickr +- blip_itm_large_flickr ​ ***modality:*** *str* @@ -95,7 +98,7 @@ Save model to local with specified format. ```python from towhee import ops -op = ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image').get_op() +op = ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op() op.save_model('onnx', 'test.onnx') ```
diff --git a/__init__.py b/__init__.py index 5b6b7ed..77179d4 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .blip import Blip -def blip(model_name: str, modality: str, device:str = None, checkpoint_path:str = None): +def blip(model_name: str, modality: str, device:str = 'cpu', checkpoint_path:str = None): return Blip(model_name, modality, device, checkpoint_path) diff --git a/blip.py b/blip.py index 3b0ac85..cc88119 100644 --- a/blip.py +++ b/blip.py @@ -26,7 +26,7 @@ from transformers import logging as t_logging from towhee import register from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color -from towhee.dc2 import accelerate +#from towhee.dc2 import accelerate log = logging.getLogger('run_op') warnings.filterwarnings('ignore') @@ -76,7 +76,7 @@ class BLIPModelText(nn.Module): return text_features class Model: - def __init__(self, model_name, modality, device, checkpoint_path): + def __init__(self, model_name, modality, checkpoint_path, device): self.model = create_model(model_name, modality, checkpoint_path, device) self.device = device @@ -143,9 +143,14 @@ class Blip(NNOperator): def _configs(self): config = {} - config['blip_itm_base'] = {} - config['blip_itm_base']['name'] = 'Salesforce/blip-itm-base-coco' - config['blip_itm_base']['image_size'] = 224 + config['blip_itm_base_coco'] = {} + config['blip_itm_base_coco']['name'] = 'Salesforce/blip-itm-base-coco' + config['blip_itm_base_flickr'] = {} + config['blip_itm_base_flickr']['name'] = 'Salesforce/blip-itm-base-flickr' + config['blip_itm_large_coco'] = {} + config['blip_itm_large_coco']['name'] = 'Salesforce/blip-itm-large-coco' + config['blip_itm_large_flickr'] = {} + config['blip_itm_large_flickr']['name'] = 'Salesforce/blip-itm-large-flickr' return config @property