# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

import numpy
import torch
from transformers import AutoTokenizer, AutoModel

from towhee.operator import NNOperator
from towhee import register

# NOTE(review): this silences *every* warning for the whole process, not only
# those raised by transformers; kept as-is for backward compatibility.
warnings.filterwarnings('ignore')

log = logging.getLogger(__name__)


@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
    """
    NLP embedding operator that uses the pretrained transformers model gathered by huggingface.

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
    """

    def __init__(self, model_name: str) -> None:
        """Load the pretrained model and its matching tokenizer by name.

        Raises:
            Exception: whatever ``from_pretrained`` raises is logged and
                re-raised unchanged.
        """
        super().__init__()
        self.model_name = model_name
        try:
            self.model = AutoModel.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load model by name: {self.model_name}')
            raise
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load tokenizer by name: {self.model_name}')
            raise

    def __call__(self, txt: str) -> numpy.ndarray:
        """Embed ``txt`` and return the model's ``last_hidden_state`` as numpy.

        The leading (batch) dimension of size 1 is squeezed away, so for a
        single string the result is presumably one vector per token, i.e.
        ``(seq_len, hidden_size)`` -- TODO confirm for non-BERT-style models.

        Raises:
            Exception: tokenizer, model and feature-extraction failures are
                logged and re-raised unchanged.
        """
        try:
            inputs = self.tokenizer(txt, return_tensors="pt")
        except Exception:
            log.error(f'Invalid input for the tokenizer: {self.model_name}')
            raise
        try:
            # Pure inference: disable autograd so no computation graph (and the
            # activations it keeps alive) is built on every call.
            with torch.no_grad():
                outs = self.model(**inputs)
        except Exception:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise
        try:
            features = outs.last_hidden_state.squeeze(0)
        except Exception:
            log.error(f'Fail to extract features by model: {self.model_name}')
            raise
        feature_vector = features.detach().numpy()
        return feature_vector