logo
Browse Source

add logging for clip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06 2 years ago
parent
commit
8fd4d32e23
  1. 25
      clip.py

25
clip.py

@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
from pathlib import Path
import torch
from torch import nn
from torchvision import transforms
import logging
import warnings
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
from transformers import logging as t_logging
# from towhee.dc2 import accelerate
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
def create_model(model_name, modality, checkpoint_path, device):
hf_clip_model = CLIPModel.from_pretrained(model_name)
@ -159,13 +167,16 @@ class Clip(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = [
'clip_vit_base_patch16',
'clip_vit_base_patch32',
'clip_vit_large_patch14',
'clip_vit_large_patch14_336'
]
full_list = [
'clip_vit_base_patch16',
'clip_vit_base_patch32',
'clip_vit_large_patch14',
'clip_vit_large_patch14_336'
]
if format == None:
model_list = full_list
elif format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = full_list
else:
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list

Loading…
Cancel
Save