diff --git a/clip.py b/clip.py
index beebab4..b3892fe 100644
--- a/clip.py
+++ b/clip.py
@@ -12,18 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
+import os
 from pathlib import Path
 
 import torch
 from torch import nn
 from torchvision import transforms
+import logging
+import warnings
 
 from towhee.types.image_utils import to_pil
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
+from transformers import logging as t_logging
 
 # from towhee.dc2 import accelerate
 
+log = logging.getLogger('run_op')
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+t_logging.set_verbosity_error()
 
 def create_model(model_name, modality, checkpoint_path, device):
     hf_clip_model = CLIPModel.from_pretrained(model_name)
@@ -159,13 +167,16 @@ class Clip(NNOperator):
 
     @staticmethod
     def supported_model_names(format: str = None):
-        if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
-            model_list = [
-                'clip_vit_base_patch16',
-                'clip_vit_base_patch32',
-                'clip_vit_large_patch14',
-                'clip_vit_large_patch14_336'
-                ]
+        full_list = [
+            'clip_vit_base_patch16',
+            'clip_vit_base_patch32',
+            'clip_vit_large_patch14',
+            'clip_vit_large_patch14_336'
+        ]
+        if format is None:
+            model_list = full_list
+        elif format in ('pytorch', 'torchscript', 'onnx'):
+            model_list = full_list
         else:
             log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
         return model_list
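
Reviewer note: in the new supported_model_names, the format-is-None branch and the known-format branch assign the same full_list, and an unrecognized format only logs an error before `return model_list` executes with the name unbound, raising UnboundLocalError; the log message also still omits "onnx" from the supported list. Below is a minimal, hypothetical sketch of the behavior the patch appears to intend, standalone and not part of clip.py; raising ValueError is a substitution for the patch's log-and-fall-through, and every name besides the four model identifiers and the three format strings is illustrative only.

    # Hypothetical standalone sketch of the patched classmethod's intent.
    SUPPORTED_FORMATS = ('pytorch', 'torchscript', 'onnx')

    def supported_model_names(format=None):
        full_list = [
            'clip_vit_base_patch16',
            'clip_vit_base_patch32',
            'clip_vit_large_patch14',
            'clip_vit_large_patch14_336',
        ]
        # format=None now means "list every model"; any known export
        # format returns the same list.
        if format is None or format in SUPPORTED_FORMATS:
            return full_list
        # The patch logs and then hits an unbound `model_list`; raising
        # here makes the failure mode explicit instead.
        raise ValueError(f'Invalid format "{format}". '
                         f'Supported formats: {", ".join(SUPPORTED_FORMATS)}.')

Under this sketch, supported_model_names() and supported_model_names('onnx') both return the full list, while supported_model_names('tf') raises rather than failing with an unbound local.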