towhee / bert-embedding
4 changed files with 99 additions and 2 deletions
@@ -1,3 +1,42 @@
# BERT Text Embedding Operator (PyTorch)

This is another test repo.

Authors: Kyle He
## Overview

This operator transforms text into an embedding using BERT[1], which stands for Bidirectional Encoder Representations from Transformers.
## Interface

```python
__call__(self, text: str)
```

**Args:**
- text:
  - the text to be embedded
  - supported types: str
**Returns:**

The operator returns a NamedTuple('Outputs', [('embs', numpy.ndarray)]) containing the following fields:
- embs:
  - embedding of the text
  - data type: `numpy.ndarray`
  - shape: (768,)
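
For example, a minimal usage sketch (the module path `torch_bert` and the input sentence are hypothetical; the class itself is defined in the operator file at the end of this diff):

```python
from torch_bert import TorchBert  # hypothetical module path for the operator file below

op = TorchBert(max_length=256)
outputs = op('Hello, Towhee!')
print(type(outputs.embs))  # numpy.ndarray
print(outputs.embs.shape)  # (768,)
```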
## Requirements

You can install the required Python packages listed in [requirements.txt](./requirements.txt), e.g. with `pip install -r requirements.txt`.
## How it works

The `towhee/torch-bert` operator is based on Huggingface[2]. It builds the BERT input embedding for the tokenized text by summing the word, position, and token-type embeddings, then applies the model's embedding LayerNorm to the first token's vector.
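
A condensed sketch of that computation, mirroring the operator code at the end of this diff (the helper name `embedding_of_first_token` is made up here; `model` is assumed to be the `BertModel` instance and `inputs` the tokenizer output):

```python
import torch

def embedding_of_first_token(model, inputs):
    emb = model.embeddings
    input_ids = inputs['input_ids'][0]
    # Sum the three BERT input embeddings for each token.
    summed = (emb.word_embeddings.weight[input_ids]
              + emb.position_embeddings.weight[torch.arange(input_ids.size(0))]
              + emb.token_type_embeddings.weight[inputs['token_type_ids'][0]])
    # Manual LayerNorm over the hidden dimension of the first token,
    # using BERT's epsilon of 1e-12.
    ex = summed[0]
    normed = (ex - ex.mean()) / torch.sqrt((ex - ex.mean()).pow(2).mean() + 1e-12)
    return emb.LayerNorm.weight * normed + emb.LayerNorm.bias
```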
## Reference

[1]. https://arxiv.org/pdf/1810.04805.pdf

[2]. https://huggingface.co/docs/transformers
@@ -0,0 +1,13 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,4 @@
torch
numpy
bertviz
transformers
@@ -0,0 +1,41 @@
from bertviz.transformers_neuron_view import BertModel, BertConfig
from transformers import BertTokenizer
from typing import NamedTuple
import numpy
import torch

from towhee.operator import Operator


class TorchBert(Operator):
    """
    Text to embedding using BERT.
    """
    def __init__(self, max_length: int = 256, framework: str = 'pytorch') -> None:
        super().__init__()
        config = BertConfig.from_pretrained("bert-base-cased", output_attentions=True,
                                            output_hidden_states=True, return_dict=True)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        config.max_position_embeddings = max_length
        self.max_length = max_length
        model = BertModel(config)
        self.model = model.eval()

    def __call__(self, text: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]):
        inputs = self.tokenizer(text, truncation=True, padding=True,
                                max_length=self.max_length, return_tensors='pt')
        # Build the BERT input embedding as the sum of word, position, and
        # token-type embeddings; the sum is parenthesized so it parses as a
        # single expression rather than three separate statements.
        f1 = (torch.index_select(self.model.embeddings.word_embeddings.weight, 0,
                                 inputs['input_ids'][0])  # word embeddings
              + torch.index_select(self.model.embeddings.position_embeddings.weight, 0,
                                   torch.arange(inputs['input_ids'][0].size(0)))  # position embeddings
              + torch.index_select(self.model.embeddings.token_type_embeddings.weight, 0,
                                   inputs['token_type_ids'][0]))  # token-type embeddings
        # Manual LayerNorm over the first token's hidden dimension.
        ex1 = f1[0, :]
        ex1_mean = ex1.mean()
        ex1_var = (ex1 - ex1_mean).pow(2).mean()  # biased variance, as LayerNorm uses
        norm_embedding = (ex1 - ex1_mean) / torch.sqrt(ex1_var + 1e-12)
        norm_embedding_centered = (self.model.embeddings.LayerNorm.weight * norm_embedding
                                   + self.model.embeddings.LayerNorm.bias)
        Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)])
        return Outputs(norm_embedding_centered.detach().numpy())