import warnings
import logging

import numpy
import torch
from transformers import LongformerTokenizer, LongformerModel
from towhee.operator import NNOperator
from towhee import register

warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)
log = logging.getLogger()


@register(output_schema=['vec'])
class Longformer(NNOperator):
    """
    NLP embedding operator that uses a pretrained Longformer model hosted on Hugging Face.

    The Longformer model was presented in "Longformer: The Long-Document Transformer"
    by Iz Beltagy, Matthew E. Peters, and Arman Cohan.

    Ref: https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/longformer#transformers.LongformerConfig

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
        global_attention_mask (`torch.Tensor`):
            Global attention mask, depending on the task.
        pooler_output (`bool`):
            Whether to return the pooled output instead of the token-level hidden states.
    """
    def __init__(
            self,
            model_name: str = 'allenai/longformer-base-4096',
            global_attention_mask: torch.Tensor = None,
            pooler_output: bool = False
    ):
        super().__init__()
        self.model_name = model_name
        self.global_attention_mask = global_attention_mask
        self.pooler_output = pooler_output
        try:
            self.model = LongformerModel.from_pretrained(model_name)
        except Exception as e:
            log.error(f'Failed to load model by name: {model_name}')
            raise e
        try:
            self.tokenizer = LongformerTokenizer.from_pretrained(model_name)
        except Exception as e:
            log.error(f'Failed to load tokenizer by name: {model_name}')
            raise e
        # Put the model in inference mode so dropout layers are disabled.
        self.model.eval()

    def __call__(self, txt: str) -> numpy.ndarray:
        try:
            # Tokenize the text and add a batch dimension.
            input_ids = torch.tensor(self.tokenizer.encode(txt)).unsqueeze(0)
        except Exception as e:
            log.error(f'Invalid input for the tokenizer: {self.model_name}')
            raise e
        try:
            # Local attention on every token; global attention is whatever the caller supplied.
            attention_mask = torch.ones(
                input_ids.shape, dtype=torch.long, device=input_ids.device
            )
            outs = self.model(
                input_ids,
                attention_mask=attention_mask,
                global_attention_mask=self.global_attention_mask
            )
        except Exception as e:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise e
        try:
            # Use the pooled representation or the full sequence of hidden states.
            if self.pooler_output:
                feature_vector = outs.pooler_output.squeeze(0)
            else:
                feature_vector = outs.last_hidden_state.squeeze(0)
        except Exception as e:
            log.error(f'Failed to extract features with model: {self.model_name}')
            raise e
        # Detach from the autograd graph and convert to a numpy array.
        vec = feature_vector.detach().numpy()
        return vec


def get_model_list():
    full_list = [
        "allenai/longformer-base-4096",
        "allenai/longformer-large-4096",
        "allenai/longformer-large-4096-finetuned-triviaqa",
        "allenai/longformer-base-4096-extra.pos.embd.only",
        "allenai/longformer-large-4096-extra.pos.embd.only",
    ]
    full_list.sort()
    return full_list
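

# Minimal usage sketch (an assumption, not part of the original operator file):
# instantiate the operator directly and embed a short string. The example text
# is illustrative. With pooler_output=True the base model yields a (768,)
# vector; otherwise the output has shape (sequence_length, 768).
if __name__ == '__main__':
    op = Longformer(model_name='allenai/longformer-base-4096', pooler_output=True)
    embedding = op('Towhee is a framework for building neural data processing pipelines.')
    print(embedding.shape)  # e.g. (768,) for the pooled representation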