# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

import numpy
import torch
from transformers import RealmTokenizer, RealmEmbedder

from towhee.operator import NNOperator
from towhee import register

# Module-level noise suppression: silence Python warnings and keep the
# transformers logger at ERROR level.
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)

log = logging.getLogger()


@register(output_schema=['vec'])
class Realm(NNOperator):
    """
    NLP embedding operator that uses the pretrained REALM model gathered by huggingface.

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
    """

    def __init__(self, model_name: str = "google/realm-cc-news-pretrained-embedder") -> None:
        """
        Load the REALM embedder and its tokenizer by name.

        Args:
            model_name (`str`):
                Name passed to `from_pretrained` for both the model and the tokenizer.

        Raises:
            Exception: whatever `from_pretrained` raises; the failure is logged
                and the original exception is re-raised unchanged.
        """
        super().__init__()
        self.model_name = model_name
        try:
            self.model = RealmEmbedder.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load model by name: {self.model_name}')
            raise
        try:
            self.tokenizer = RealmTokenizer.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load tokenizer by name: {self.model_name}')
            raise

    def __call__(self, txt: str) -> numpy.ndarray:
        """
        Embed a piece of text with the loaded REALM embedder.

        Args:
            txt (`str`):
                Text to tokenize and embed.

        Returns:
            `numpy.ndarray`: the model's `projected_score` output with
            dimension 0 squeezed, converted to a numpy array.

        Raises:
            Exception: any error from tokenization, the forward pass or the
                feature extraction; each is logged and re-raised unchanged.
        """
        try:
            inputs = self.tokenizer(txt, return_tensors="pt")
        except Exception:
            log.error(f'Invalid input for the tokenizer: {self.model_name}')
            raise
        try:
            # Inference only: disable gradient tracking so no autograd graph
            # is built and then discarded on every call.
            with torch.no_grad():
                outs = self.model(**inputs)
        except Exception:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise
        try:
            features = outs.projected_score.squeeze(0)
        except Exception:
            log.error(f'Fail to extract features by model: {self.model_name}')
            raise
        vec = features.detach().numpy()
        return vec


def get_model_list():
    """Return the names of the supported REALM models as a sorted list."""
    model_names = [
        "google/realm-cc-news-pretrained-embedder",
    ]
    return sorted(model_names)