logo
Browse Source

add logging for clip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06 2 years ago
parent
commit
8fd4d32e23
  1. 25
      clip.py

25
clip.py

@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import os
from pathlib import Path from pathlib import Path
import torch import torch
from torch import nn from torch import nn
from torchvision import transforms from torchvision import transforms
import logging
import warnings
from towhee.types.image_utils import to_pil from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
from towhee import register from towhee import register
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
from transformers import logging as t_logging
# from towhee.dc2 import accelerate # from towhee.dc2 import accelerate
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
def create_model(model_name, modality, checkpoint_path, device): def create_model(model_name, modality, checkpoint_path, device):
hf_clip_model = CLIPModel.from_pretrained(model_name) hf_clip_model = CLIPModel.from_pretrained(model_name)
@ -159,13 +167,16 @@ class Clip(NNOperator):
@staticmethod @staticmethod
def supported_model_names(format: str = None): def supported_model_names(format: str = None):
if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = [
'clip_vit_base_patch16',
'clip_vit_base_patch32',
'clip_vit_large_patch14',
'clip_vit_large_patch14_336'
]
full_list = [
'clip_vit_base_patch16',
'clip_vit_base_patch32',
'clip_vit_large_patch14',
'clip_vit_large_patch14_336'
]
if format == None:
model_list = full_list
elif format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = full_list
else: else:
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list return model_list

Loading…
Cancel
Save