logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

35 lines
1.4 KiB

import numpy
from typing import NamedTuple
import torch
from transformers import LongformerTokenizer, LongformerModel
from towhee.operator import NNOperator
import warnings
# Module-level side effect: installs an 'ignore' filter with no message/category
# restriction, so every warning is suppressed process-wide once this module is
# imported (not just warnings raised by this operator).
warnings.filterwarnings('ignore')
class NlpLongformer(NNOperator):
    """
    NLP embedding operator that uses the pretrained Longformer model gathered by huggingface.

    The Longformer model was presented in "Longformer: The Long-Document Transformer" by
    Iz Beltagy, Matthew E. Peters, Arman Cohan.
    Ref: https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/longformer#transformers.LongformerConfig

    Args:
        model_name (`str`):
            Which model to use for the embeddings. Passed unchanged to
            `LongformerModel.from_pretrained` and `LongformerTokenizer.from_pretrained`.
    """

    # Result type, built once so that the return annotation of `__call__` and the
    # instances it returns share a single class instead of re-creating it per call.
    _Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

    def __init__(self, model_name: str) -> None:
        # Run the NNOperator base initialiser so whatever state it sets up exists
        # on this instance before the model and tokenizer are attached.
        super().__init__()
        self.model = LongformerModel.from_pretrained(model_name)
        self.tokenizer = LongformerTokenizer.from_pretrained(model_name)

    def __call__(self, txt: str) -> _Outputs:
        """
        Compute the embedding of a single text.

        Args:
            txt (`str`):
                The text to embed. It is tokenized with the operator's tokenizer
                and fed to the model as a batch of one sequence.

        Returns:
            An `Outputs` named tuple whose `feature_vector` field is the model's
            second output (the pooled output, per the Hugging Face Longformer
            docs) with the leading batch axis squeezed away, as a numpy array.
        """
        # encode() yields a flat list of token ids; unsqueeze adds the batch axis.
        input_ids = torch.tensor(self.tokenizer.encode(txt)).unsqueeze(0)
        # Inference only: skip autograd bookkeeping. No attention mask is given
        # (the model default is used) and the per-layer hidden states are not
        # requested because only the second output is read below.
        with torch.no_grad():
            outs = self.model(input_ids)
        # Positional index works for both tuple and ModelOutput return styles.
        feature_vector = outs[1].squeeze(0)
        # Tensors produced under no_grad carry no graph, so they convert directly.
        return self._Outputs(feature_vector.numpy())

    def get_model(self):
        """Return the underlying `LongformerModel` instance held by this operator."""
        return self.model