
Refactor

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Jael Gu · 2 years ago
branch: main · commit 1389d9923c
1. README.md (74 changed lines)
2. __init__.py (6 changed lines)
3. auto_transformers.py (134 changed lines)

README.md (74 changed lines)

@@ -6,69 +6,75 @@
 ## Description
-A text embedding operator implemented with pretrained models from [Huggingface Transformers](https://huggingface.co/docs/transformers).
+A text embedding operator takes a sentence, paragraph, or document in string as an input
+and outputs an embedding vector in ndarray which captures the input's core semantic elements.
+This operator is implemented with pretrained models from [Huggingface Transformers](https://huggingface.co/docs/transformers).
+## Code Example
+Use the pretrained model 'distilbert-base-cased'
+to generate a text embedding for the sentence "Hello, world.".
+*Write the pipeline in simplified style*:
 ```python
-from towhee import ops
+from towhee import dc
-text_encoder = ops.text_embedding.transformers(model_name="bert-base-cased")
-text_embedding = text_encoder("Hello, world.")
+dc.stream(["Hello, world."])
+    .text_embedding.transformers('distilbert-base-cased')
+    .show()
 ```
-## Factory Constructor
+*Write the same pipeline with explicit input/output name specifications:*
-Create the operator via the following factory method
+```python
+from towhee import dc
-***ops.text_embedding.transformers(model_name)***
+dc.stream['txt'](["Hello, world."])
+    .text_embedding.transformers['txt', 'vec']('distilbert-base-cased')
+    .select('txt', 'vec')
+    .show()
+```
-## Interface
+## Factory Constructor
-A text embedding operator takes a sentence, paragraph, or document in string as an input
-and output an embedding vector in ndarray which captures the input's core semantic elements.
+Create the operator via the following factory method
+***text_embedding.transformers(model_name="bert-base-uncased")***
 **Parameters:**
-***text***: *str*
-    The text in string.
 ***model_name***: *str*
     The model name in string.
+You can get the list of supported model names by calling `model_list` of the operator:
+```python
+from towhee import ops
-**Returns**: *numpy.ndarray*
-    The text embedding extracted by model.
+ops.text_embedding.transformers.model_list()
+```
+## Interface
-## Code Example
+The operator takes a text in string as input.
+It loads the tokenizer and pre-trained model using the model name.
+Text embeddings are returned in ndarray.
-Use the pretrained Bert-Base-Cased model ('bert-base-cased')
-to generate a text embedding for the sentence "Hello, world.".
-*Write the pipeline in simplified style*:
+**Parameters:**
-```python
-import towhee.DataCollection as dc
+***text***: *str*
-dc.glob("Hello, world.")
-    .text_embedding.transformers('bert-base-cased')
-    .show()
-```
+    The text in string.
-*Write a same pipeline with explicit inputs/outputs name specifications:*
-```python
-from towhee import DataCollection as dc
-dc.glob['text']('Hello, world.')
-    .text_embedding.transformers['text', 'vec']('bert-base-cased')
-    .select('vec')
-    .show()
-```
+**Returns**: *numpy.ndarray*
+    The text embedding extracted by the model.

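For reference, the removed README example called the operator directly rather than through a DataCollection pipeline. Below is a minimal sketch of that direct-call style against the refactored operator, assuming the `towhee.ops` entry point still resolves the operator the same way as in the removed example; the model name is simply the one used in the new examples.

```python
from towhee import ops

# Build the operator via the factory with an explicit model name.
text_encoder = ops.text_embedding.transformers(model_name='distilbert-base-cased')

# Calling the operator on a string returns a numpy.ndarray embedding.
vec = text_encoder('Hello, world.')
print(vec.shape)
```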
__init__.py (6 changed lines)

@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .auto_transformers import AutoTransformers
+from .auto_transformers import AutoTransformers, get_model_list
 def transformers(model_name: str):
     return AutoTransformers(model_name)
+def model_list():
+    return get_model_list()
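The two module-level wrappers above simply forward to `AutoTransformers` and `get_model_list`. A hedged sketch of using them together, validating a model name against `model_list()` before constructing the operator; the `ops.text_embedding.transformers` path follows the README, while the surrounding control flow is illustrative only.

```python
from towhee import ops

requested = 'distilbert-base-cased'

# model_list() forwards to get_model_list() and returns the supported names.
supported = ops.text_embedding.transformers.model_list()

if requested in supported:
    encoder = ops.text_embedding.transformers(model_name=requested)
    embedding = encoder('Hello, world.')  # numpy.ndarray
else:
    print(f'{requested} is not among the {len(supported)} supported models.')
```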

auto_transformers.py (134 changed lines)

@@ -23,6 +23,7 @@ from towhee import register
 import warnings
 warnings.filterwarnings('ignore')
+logging.getLogger("transformers").setLevel(logging.ERROR)
 log = logging.getLogger()
@@ -35,13 +36,17 @@ class AutoTransformers(NNOperator):
             Which model to use for the embeddings.
     """
-    def __init__(self, model_name: str) -> None:
+    def __init__(self, model_name: str = "bert-base-uncased") -> None:
         super().__init__()
         self.model_name = model_name
         try:
             self.model = AutoModel.from_pretrained(model_name)
         except Exception as e:
-            log.error(f'Fail to load model by name: {self.model_name}')
+            model_list = get_model_list()
+            if model_name not in model_list:
+                log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}")
+            else:
+                log.error(f"Fail to load model by name: {self.model_name}")
             raise e
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
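The reworked constructor now distinguishes an unsupported model name from a genuine load failure before re-raising. A small sketch of what that means for callers, assuming the module is importable as `auto_transformers`; the invalid name is made up for illustration.

```python
from auto_transformers import AutoTransformers  # import path assumed

try:
    # An unsupported name now logs "Invalid model name: ..." together with the
    # supported list; a known name that fails to load logs the original
    # "Fail to load model by name" message. Either way the exception is re-raised.
    op = AutoTransformers(model_name='no-such-model')
except Exception:
    # Fall back to the new default, "bert-base-uncased".
    op = AutoTransformers()
```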
@@ -65,6 +70,127 @@ class AutoTransformers(NNOperator):
         except Exception as e:
             log.error(f'Fail to extract features by model: {self.model_name}')
             raise e
-        feature_vector = features.detach().numpy()
-        return feature_vector
+        vec = features.detach().numpy()
+        return vec
+def get_model_list():
+    full_list = [
+        "bert-large-uncased",
+        "bert-base-cased",
+        "bert-large-cased",
+        "bert-base-multilingual-uncased",
+        "bert-base-multilingual-cased",
+        "bert-base-chinese",
+        "bert-base-german-cased",
+        "bert-large-uncased-whole-word-masking",
+        "bert-large-cased-whole-word-masking",
+        "bert-large-uncased-whole-word-masking-finetuned-squad",
+        "bert-large-cased-whole-word-masking-finetuned-squad",
+        "bert-base-cased-finetuned-mrpc",
+        "bert-base-german-dbmdz-cased",
+        "bert-base-german-dbmdz-uncased",
+        "cl-tohoku/bert-base-japanese-whole-word-masking",
+        "cl-tohoku/bert-base-japanese-char",
+        "cl-tohoku/bert-base-japanese-char-whole-word-masking",
+        "TurkuNLP/bert-base-finnish-cased-v1",
+        "TurkuNLP/bert-base-finnish-uncased-v1",
+        "wietsedv/bert-base-dutch-cased",
+        "google/bigbird-roberta-base",
+        "google/bigbird-roberta-large",
+        "google/bigbird-base-trivia-itc",
+        "albert-base-v1",
+        "albert-large-v1",
+        "albert-xlarge-v1",
+        "albert-xxlarge-v1",
+        "albert-base-v2",
+        "albert-large-v2",
+        "albert-xlarge-v2",
+        "albert-xxlarge-v2",
+        "facebook/bart-large",
+        "google/bert_for_seq_generation_L-24_bbc_encoder",
+        "google/bigbird-pegasus-large-arxiv",
+        "google/bigbird-pegasus-large-pubmed",
+        "google/bigbird-pegasus-large-bigpatent",
+        "google/canine-s",
+        "google/canine-c",
+        "YituTech/conv-bert-base",
+        "YituTech/conv-bert-medium-small",
+        "YituTech/conv-bert-small",
+        "ctrl",
+        "microsoft/deberta-base",
+        "microsoft/deberta-large",
+        "microsoft/deberta-xlarge",
+        "microsoft/deberta-base-mnli",
+        "microsoft/deberta-large-mnli",
+        "microsoft/deberta-xlarge-mnli",
+        "distilbert-base-uncased",
+        "distilbert-base-uncased-distilled-squad",
+        "distilbert-base-cased",
+        "distilbert-base-cased-distilled-squad",
+        "distilbert-base-german-cased",
+        "distilbert-base-multilingual-cased",
+        "distilbert-base-uncased-finetuned-sst-2-english",
+        "google/electra-small-generator",
+        "google/electra-base-generator",
+        "google/electra-large-generator",
+        "google/electra-small-discriminator",
+        "google/electra-base-discriminator",
+        "google/electra-large-discriminator",
+        "google/fnet-base",
+        "google/fnet-large",
+        "facebook/wmt19-ru-en",
+        "funnel-transformer/small",
+        "funnel-transformer/small-base",
+        "funnel-transformer/medium",
+        "funnel-transformer/medium-base",
+        "funnel-transformer/intermediate",
+        "funnel-transformer/intermediate-base",
+        "funnel-transformer/large",
+        "funnel-transformer/large-base",
+        "funnel-transformer/xlarge-base",
+        "funnel-transformer/xlarge",
+        "gpt2",
+        "gpt2-medium",
+        "gpt2-large",
+        "gpt2-xl",
+        "distilgpt2",
+        "EleutherAI/gpt-neo-1.3B",
+        "EleutherAI/gpt-j-6B",
+        "kssteven/ibert-roberta-base",
+        "allenai/led-base-16384",
+        "google/mobilebert-uncased",
+        "microsoft/mpnet-base",
+        "uw-madison/nystromformer-512",
+        "openai-gpt",
+        "google/reformer-crime-and-punishment",
+        "tau/splinter-base",
+        "tau/splinter-base-qass",
+        "tau/splinter-large",
+        "tau/splinter-large-qass",
+        "squeezebert/squeezebert-uncased",
+        "squeezebert/squeezebert-mnli",
+        "squeezebert/squeezebert-mnli-headless",
+        "transfo-xl-wt103",
+        "xlm-mlm-en-2048",
+        "xlm-mlm-ende-1024",
+        "xlm-mlm-enfr-1024",
+        "xlm-mlm-enro-1024",
+        "xlm-mlm-tlm-xnli15-1024",
+        "xlm-mlm-xnli15-1024",
+        "xlm-clm-enfr-1024",
+        "xlm-clm-ende-1024",
+        "xlm-mlm-17-1280",
+        "xlm-mlm-100-1280",
+        "xlm-roberta-base",
+        "xlm-roberta-large",
+        "xlm-roberta-large-finetuned-conll02-dutch",
+        "xlm-roberta-large-finetuned-conll02-spanish",
+        "xlm-roberta-large-finetuned-conll03-english",
+        "xlm-roberta-large-finetuned-conll03-german",
+        "xlnet-base-cased",
+        "xlnet-large-cased",
+        "uw-madison/yoso-4096",
+    ]
+    full_list.sort()
+    return full_list
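Since `get_model_list()` returns the names alphabetically sorted, downstream code can filter the list directly. A small, hypothetical helper snippet (import path assumed, not part of the commit):

```python
from auto_transformers import get_model_list  # import path assumed

models = get_model_list()  # alphabetically sorted list of supported names
distilbert_models = [m for m in models if m.startswith('distilbert')]
print(len(models), distilbert_models)
```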
