|
|
@ -26,7 +26,7 @@ from transformers import logging as t_logging |
|
|
|
from towhee import register |
|
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee.types.arg import arg, to_image_color |
|
|
|
from towhee.dc2 import accelerate |
|
|
|
#from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
log = logging.getLogger('run_op') |
|
|
|
warnings.filterwarnings('ignore') |
|
|
@ -76,7 +76,7 @@ class BLIPModelText(nn.Module): |
|
|
|
return text_features |
|
|
|
|
|
|
|
class Model: |
|
|
|
def __init__(self, model_name, modality, device, checkpoint_path): |
|
|
|
def __init__(self, model_name, modality, checkpoint_path, device): |
|
|
|
self.model = create_model(model_name, modality, checkpoint_path, device) |
|
|
|
self.device = device |
|
|
|
|
|
|
@ -143,9 +143,14 @@ class Blip(NNOperator): |
|
|
|
|
|
|
|
def _configs(self): |
|
|
|
config = {} |
|
|
|
config['blip_itm_base'] = {} |
|
|
|
config['blip_itm_base']['name'] = 'Salesforce/blip-itm-base-coco' |
|
|
|
config['blip_itm_base']['image_size'] = 224 |
|
|
|
config['blip_itm_base_coco'] = {} |
|
|
|
config['blip_itm_base_coco']['name'] = 'Salesforce/blip-itm-base-coco' |
|
|
|
config['blip_itm_base_flickr'] = {} |
|
|
|
config['blip_itm_base_flickr']['name'] = 'Salesforce/blip-itm-base-flickr' |
|
|
|
config['blip_itm_large_coco'] = {} |
|
|
|
config['blip_itm_large_coco']['name'] = 'Salesforce/blip-itm-large-coco' |
|
|
|
config['blip_itm_large_flickr'] = {} |
|
|
|
config['blip_itm_large_flickr']['name'] = 'Salesforce/blip-itm-large-flickr' |
|
|
|
return config |
|
|
|
|
|
|
|
@property |
|
|
|