import numpy
import torch
from transformers import LongformerTokenizer, LongformerModel
from towhee.operator import NNOperator
from towhee import register
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)
log = logging.getLogger()

@register(output_schema=['vec'])
class Longformer(NNOperator):
    """
    NLP embedding operator that uses a pretrained Longformer model from Hugging Face Transformers.
    The Longformer model was introduced in "Longformer: The Long-Document Transformer" by Iz Beltagy,
    Matthew E. Peters, and Arman Cohan.
    Ref: https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/longformer#transformers.LongformerConfig

    Args:
        model_name (`str`):
            Which pretrained model to use for the embeddings.
    """
    def __init__(self, model_name: str = 'allenai/longformer-base-4096') -> None:
        super().__init__()
        self.model_name = model_name
        try:
            # Load the pretrained Longformer weights.
            self.model = LongformerModel.from_pretrained(model_name)
        except Exception as e:
            log.error(f'Failed to load model by name: {model_name}')
            raise e
        try:
            # Load the matching tokenizer.
            self.tokenizer = LongformerTokenizer.from_pretrained(model_name)
        except Exception as e:
            log.error(f'Failed to load tokenizer by name: {model_name}')
            raise e
    def __call__(self, txt: str) -> numpy.ndarray:
        try:
            # Tokenize the input text and add a batch dimension.
            input_ids = torch.tensor(self.tokenizer.encode(txt)).unsqueeze(0)
        except Exception as e:
            log.error(f'Invalid input for the tokenizer: {self.model_name}')
            raise e
        try:
            attention_mask = None
            # The base LongformerModel does not accept a `labels` argument, so only
            # the inputs and attention mask are passed to the forward call.
            outs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        except Exception as e:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise e
        try:
            # outs[1] is the pooled output with shape (1, hidden_size).
            feature_vector = outs[1].squeeze()
        except Exception as e:
            log.error(f'Failed to extract features with model: {self.model_name}')
            raise e
        vec = feature_vector.detach().numpy()
        return vec

def get_model_list():
    full_list = [
        "allenai/longformer-base-4096",
        "allenai/longformer-large-4096",
        "allenai/longformer-large-4096-finetuned-triviaqa",
        "allenai/longformer-base-4096-extra.pos.embd.only",
        "allenai/longformer-large-4096-extra.pos.embd.only",
    ]
    full_list.sort()
    return full_list
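

# Example usage (a minimal sketch relying only on the class and helper defined above;
# the sentence and the printed shape are illustrative):
if __name__ == '__main__':
    # Instantiate the operator with the default checkpoint and embed a short text.
    op = Longformer('allenai/longformer-base-4096')
    vec = op('Longformer scales self-attention to long documents.')
    print(vec.shape)  # pooled-output embedding, e.g. (768,) for the base model

    # List the checkpoints this operator is expected to work with.
    print(get_model_list())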