From 0286f1790732f77a34ff9100d17e28ba6fe80283 Mon Sep 17 00:00:00 2001
From: wxywb
Date: Mon, 13 Feb 2023 11:56:57 +0000
Subject: [PATCH] update the BLIP operator.

Signed-off-by: wxywb
---
 README.md   | 32 ++++++++++++++++++++++++++++++++
 __init__.py |  4 ++--
 blip.py     | 11 ++++++-----
 3 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index ad1d3f8..2211a84 100644
--- a/README.md
+++ b/README.md
@@ -77,6 +77,38 @@ Create the operator via the following factory method
 An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generate an embedding in ndarray.
 
+***save_model(format='pytorch', path='default')***
+
+Save the model locally in the specified format.
+
+**Parameters:**
+
+***format***: *str*
+
+​ The format of the saved model; defaults to 'pytorch'.
+
+***path***: *str*
+
+​ The path the model is saved to. By default, the model is saved to the operator directory.
+
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image').get_op()
+op.save_model('onnx', 'test.onnx')
+```
+<br />
+
+
+
+**Parameters:**
+
+​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+
+​ The data (image or text, depending on the specified modality) to generate an embedding for.
+
+
 
 **Parameters:**
 
diff --git a/__init__.py b/__init__.py
index 3a4024d..5b6b7ed 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,5 +15,5 @@
 from .blip import Blip
 
 
-def blip(model_name: str, modality: str):
-    return Blip(model_name, modality)
+def blip(model_name: str, modality: str, device:str = None, checkpoint_path:str = None):
+    return Blip(model_name, modality, device, checkpoint_path)
diff --git a/blip.py b/blip.py
index 5eecd61..3b0ac85 100644
--- a/blip.py
+++ b/blip.py
@@ -76,7 +76,7 @@ class BLIPModelText(nn.Module):
         return text_features
 
 class Model:
-    def __init__(self, model_name, modality, checkpoint_path, device):
+    def __init__(self, model_name, modality, device, checkpoint_path):
         self.model = create_model(model_name, modality, checkpoint_path, device)
         self.device = device
 
@@ -165,10 +165,11 @@ class Blip(NNOperator):
 
     @staticmethod
     def supported_model_names(format: str = None):
-        if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
-            model_list = [
-                'blip_itm_base',
-            ]
+        full_list = ['blip_itm_base']
+        if format == None:
+            model_list = full_list
+        elif format == 'pytorch' or format == 'torchscript' or format == 'onnx':
+            model_list = full_list
         else:
             log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
         return model_list
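
For context on the factory change above, here is a minimal usage sketch of the updated signature. It relies only on the `ops` accessor and `get_op()` call already shown in the README example; the `modality='text'` choice, the `device`/`checkpoint_path` values, and the sample caption are illustrative assumptions, not part of the patch.

```python
from towhee import ops

# Build a text-side BLIP operator with the two keyword arguments added in this
# patch. device=None is assumed to let the operator pick CUDA/CPU on its own,
# and checkpoint_path=None to fall back to the pretrained 'blip_itm_base' weights.
op = ops.image_text_embedding.blip(
    model_name='blip_itm_base',
    modality='text',
    device=None,
    checkpoint_path=None,
).get_op()

# Calling the operator returns an ndarray embedding for the given text
# (a towhee image would be passed instead when modality='image').
embedding = op('two dogs playing fetch on the beach')
print(embedding.shape)
```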