diff --git a/README.md b/README.md index 85c9cea..8ff6262 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,8 @@ -# Text Embedding with longformer +# Text Embedding with Longformer *author: Kyle He* - ## Desription This operator uses Longformer to convert long text to embeddings. @@ -23,26 +22,42 @@ length, making it easy to process documents of thousands of tokens or longer[2]. [2].https://arxiv.org/pdf/2004.05150.pdf +## Code Example + +Use the pretrained model "allenai/longformer-base-4096" +to generate a text embedding for the sentence "Hello, world.". + + *Write the pipeline*: + ```python -from towhee import ops +from towhee import dc -text_encoder = ops.text_embedding.longformer(model_name="allenai/longformer-base-4096") -text_embedding = text_encoder("Hello, world.") + +dc.stream(["Hello, world."]) + .text_embedding.longformer("allenai/longformer-base-4096") + .show() ``` ## Factory Constructor Create the operator via the following factory method -***ops.text_embedding.longformer(model_name)*** +***text_embedding.longformer(model_name="allenai/longformer-base-4096")*** +**Parameters:** +​ ***model_name***: *str* -## Interface +​ The model name as a string. +The default value is "allenai/longformer-base-4096". +You can get the list of supported model names by calling `get_model_list` from [longformer.py](https://towhee.io/text-embedding/longformer/src/branch/main/longformer.py). -A text embedding operator takes a sentence, paragraph, or document in string as an input -and output an embedding vector in ndarray which captures the input's core semantic elements. +## Interface + +The operator takes a text string as input. +It loads the tokenizer and pre-trained model using the model name, +and then returns the text embedding as an ndarray. **Parameters:** @@ -52,34 +67,8 @@ and output an embedding vector in ndarray which captures the input's core semant -**Returns**: *numpy.ndarray* +**Returns**: -​ The text embedding extracted by model. 
+​ *numpy.ndarray* - - -## Code Example - -Use the pretrained model ('allenai/longformer-base-4096') -to generate a text embedding for the sentence "Hello, world.". - - *Write the pipeline in simplified style*: - -```python -import towhee.DataCollection as dc - -dc.glob("Hello, world.") - .text_embedding.longformer('longformer-base-4096') - .show() -``` - -*Write a same pipeline with explicit inputs/outputs name specifications:* - -```python -from towhee import DataCollection as dc - -dc.glob['text']('Hello, world.') - .text_embedding.longformer['text', 'vec']('longformer-base-4096') - .select('vec') - .show() -``` +​ The text embedding extracted by model. \ No newline at end of file diff --git a/__init__.py b/__init__.py index 2832034..22be76d 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .longformer import Longformer -def longformer(model_name: str): - return Longformer(model_name) +def longformer(**kwargs): + return Longformer(**kwargs) diff --git a/longformer.py b/longformer.py index 3126249..f0f420a 100644 --- a/longformer.py +++ b/longformer.py @@ -1,14 +1,16 @@ import numpy import torch from transformers import LongformerTokenizer, LongformerModel -import logging from towhee.operator import NNOperator from towhee import register - import warnings +import logging + + warnings.filterwarnings('ignore') +logging.getLogger("transformers").setLevel(logging.ERROR) log = logging.getLogger() @@ -24,7 +26,7 @@ class Longformer(NNOperator): model_name (`str`): Which model to use for the embeddings. 
""" - def __init__(self, model_name: str) -> None: + def __init__(self, model_name: str = 'allenai/longformer-base-4096') -> None: super().__init__() self.model_name = model_name try: @@ -55,5 +57,17 @@ class Longformer(NNOperator): except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e - feature_vector = feature_vector.detach().numpy() - return feature_vector + vec = feature_vector.detach().numpy() + return vec + + +def get_model_list(): + full_list = [ + "allenai/longformer-base-4096", + "allenai/longformer-large-4096", + "allenai/longformer-large-4096-finetuned-triviaqa", + "allenai/longformer-base-4096-extra.pos.embd.only", + "allenai/longformer-large-4096-extra.pos.embd.only", + ] + full_list.sort() + return full_list