towhee/bert-embedding
4 changed files with 99 additions and 2 deletions
@@ -1,3 +1,42 @@
# bert-embedding
# BERT Text Embedding Operator (PyTorch)

This is another test repo
Authors: Kyle He

## Overview

This operator transforms text into embeddings using BERT[1], which stands for
Bidirectional Encoder Representations from Transformers.

## Interface

```python
__call__(self, text: str)
```

**Args:**

- text:
  - the text to be embedded
  - supported types: str

**Returns:**

The operator returns a NamedTuple `Outputs` with the following field:

- embs:
  - embeddings of the text
  - data type: `numpy.ndarray`
  - shape: (768,)

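For illustration, a minimal usage sketch (the module path `torch_bert` is an assumption; import `TorchBert` from wherever this repo's operator file lives):

```python
from torch_bert import TorchBert  # hypothetical module name for this repo's file

# Instantiate the operator with the default maximum sequence length.
op = TorchBert(max_length=256)

# __call__ returns a NamedTuple with a single `embs` field.
outputs = op('Hello, Towhee!')
print(outputs.embs.shape)  # (768,)
```
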
## Requirements

You can install the required Python packages listed in [requirements.txt](./requirements.txt), e.g. with `pip install -r requirements.txt`.

## How it works

The `towhee/torch-bert` operator is built on Hugging Face Transformers[2]: it looks up BERT's word, position, and token-type embeddings for the tokenized input, sums them, and applies the embedding LayerNorm to produce the output vector.
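
As a rough cross-check, the same input embedding can be obtained from a stock Hugging Face `BertModel` (a sketch, not this repo's code; it loads pretrained `bert-base-cased` weights, whereas the operator builds its model from a config):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertModel.from_pretrained('bert-base-cased').eval()

inputs = tokenizer('Hello, Towhee!', return_tensors='pt')
with torch.no_grad():
    # BertEmbeddings sums word, position, and token-type embeddings and
    # applies LayerNorm; this is the computation the operator does by hand.
    embs = model.embeddings(input_ids=inputs['input_ids'],
                            token_type_ids=inputs['token_type_ids'])
print(embs.shape)  # (1, seq_len, 768)
```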

## References

[1] https://arxiv.org/pdf/1810.04805.pdf
[2] https://huggingface.co/docs/transformers
@@ -0,0 +1,13 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,4 @@
torch
numpy
bertviz
transformers
@@ -0,0 +1,41 @@
from bertviz.transformers_neuron_view import BertModel, BertConfig
from transformers import BertTokenizer
from typing import NamedTuple
import numpy
import torch

from towhee.operator import Operator


class TorchBert(Operator):
    """
    Text to embedding using BERT.
    """
    def __init__(self, max_length: int = 256, framework: str = 'pytorch') -> None:
        super().__init__()
        # Config for bert-base-cased that also exposes attentions and hidden states.
        config = BertConfig.from_pretrained("bert-base-cased", output_attentions=True,
                                            output_hidden_states=True, return_dict=True)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        config.max_position_embeddings = max_length
        self.max_length = max_length
        model = BertModel(config)
        self.model = model.eval()

    def __call__(self, text: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]):
        inputs = self.tokenizer(text, truncation=True, padding=True,
                                max_length=self.max_length, return_tensors='pt')
        seq_len = inputs['input_ids'][0].size(0)
        # Sum BERT's three input embeddings. The sum must stay inside a single
        # parenthesized expression; as standalone `+ ...` statements, the position
        # and token-type terms would be computed and silently discarded.
        f1 = (torch.index_select(self.model.embeddings.word_embeddings.weight, 0,
                                 inputs['input_ids'][0])            # word embeddings
              + torch.index_select(self.model.embeddings.position_embeddings.weight, 0,
                                   torch.arange(seq_len))           # position embeddings
              + torch.index_select(self.model.embeddings.token_type_embeddings.weight, 0,
                                   inputs['token_type_ids'][0]))    # token type embeddings
        # Normalize the first token's summed embedding by hand (LayerNorm).
        ex1 = f1[0, :]
        ex1_mean = ex1.mean()
        ex1_var = (ex1 - ex1_mean).pow(2).mean()  # variance, not standard deviation
        norm_embedding = (ex1 - ex1_mean) / torch.sqrt(ex1_var + 1e-12)
        # Apply the embedding layer's learned LayerNorm scale and shift.
        norm_embedding_centered = self.model.embeddings.LayerNorm.weight * norm_embedding \
            + self.model.embeddings.LayerNorm.bias
        Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)])
        return Outputs(norm_embedding_centered.detach().numpy())
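
The hand-rolled normalization in `__call__` should agree with a standard LayerNorm. A quick sanity-check sketch (the helper `manual_layernorm` is illustrative, not part of this repo):

```python
import torch

def manual_layernorm(x: torch.Tensor, weight: torch.Tensor,
                     bias: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Normalize a 1-D tensor the way TorchBert.__call__ does it by hand."""
    mean = x.mean()
    var = (x - mean).pow(2).mean()
    return weight * ((x - mean) / torch.sqrt(var + eps)) + bias

# Compare against torch.nn.LayerNorm with the same parameters.
ln = torch.nn.LayerNorm(768, eps=1e-12)
x = torch.randn(768)
assert torch.allclose(manual_layernorm(x, ln.weight, ln.bias), ln(x), atol=1e-5)
```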