# Copyright 2021 Zilliz. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import numpy from typing import Union, List from pathlib import Path import torch from sentence_transformers import SentenceTransformer from towhee.operator import NNOperator # from towhee.dc2 import accelerate import os import warnings warnings.filterwarnings('ignore') logging.getLogger('sentence_transformers').setLevel(logging.ERROR) log = logging.getLogger('op_sbert') class STransformers(NNOperator): """ Operator using pretrained Sentence Transformers """ def __init__(self, model_name: str = None, device: str = None): self.model_name = model_name if device: self.device = device else: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' if self.model_name: self.model = SentenceTransformer(model_name_or_path=self.model_name, device=self.device) else: log.warning('The operator is initialized without specified model.') pass def __call__(self, txt: Union[List[str], str]): if isinstance(txt, str): sentences = [txt] else: sentences = txt embs = self.model.encode(sentences) # return numpy.ndarray if isinstance(txt, str): embs = embs.squeeze(0) else: embs = list(embs) return embs @staticmethod def supported_model_names(format: str = None): import requests req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html") data = req.text full_list = [] for line in data.split('\r\n'): line = line.replace(' ', '') if line.startswith('"name":'): name = line.split(':')[-1].replace('"', '').replace(',', '') full_list.append(name) full_list.sort() if format is None: model_list = full_list elif format == 'pytorch': to_remove = [] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) else: raise ValueError(f'Invalid or unsupported format "{format}".') log.error(f'Invalid or unsupported format "{format}".') return model_list