From 67e5aa880f9d51dc50361a32c589855d820070b2 Mon Sep 17 00:00:00 2001
From: oneseer
Date: Thu, 20 Jan 2022 20:02:56 +0800
Subject: [PATCH] add text embedding implementation

---
 README.md        | 43 +++++++++++++++++++++++++++++++++++++++++--
 __init__.py      | 13 +++++++++++++
 requirements.txt |  4 ++++
 torch_bert.py    | 41 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 99 insertions(+), 2 deletions(-)
 create mode 100644 __init__.py
 create mode 100644 requirements.txt
 create mode 100644 torch_bert.py

diff --git a/README.md b/README.md
index df313be..d088eef 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,42 @@
-# bert-embedding
+# BERT Text Embedding Operator (PyTorch)
 
-This is another test repo
\ No newline at end of file
+Author: Kyle He
+
+## Overview
+
+This operator transforms text into an embedding using BERT[1], which stands for
+Bidirectional Encoder Representations from Transformers.
+
+## Interface
+
+```python
+__call__(self, text: str)
+```
+
+**Args:**
+
+- text:
+  - the text to be embedded
+  - supported types: `str`
+
+**Returns:**
+
+The operator returns a `NamedTuple('Outputs', [('embs', numpy.ndarray)])` containing the following fields:
+
+- embs:
+  - embedding of the text
+  - data type: `numpy.ndarray`
+  - shape: (768,)
+
+## Requirements
+
+You can install the required Python packages from [requirements.txt](./requirements.txt).
+
+## How it works
+
+The `towhee/torch-bert` operator is based on Hugging Face Transformers[2].
+
+## Reference
+
+[1]. https://arxiv.org/pdf/1810.04805.pdf
+[2]. https://huggingface.co/docs/transformers
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..37f5bd7
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
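The `__call__` interface described in the README above can be exercised with a short usage sketch. This assumes the packages from requirements.txt plus the towhee framework are installed and that `torch_bert.py` is importable from the working directory; it is an illustration, not part of the patch:

```python
# Minimal usage sketch for the operator interface documented in the README.
# Assumes torch_bert.py (added by this patch) is on the Python path.
from torch_bert import TorchBert

op = TorchBert(max_length=256)
outputs = op('BERT stands for Bidirectional Encoder Representations from Transformers.')
print(outputs.embs.shape)  # expected: (768,)
print(outputs.embs.dtype)  # expected: float32
```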
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e96a9ba
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch
+numpy
+bertviz
+transformers
\ No newline at end of file
diff --git a/torch_bert.py b/torch_bert.py
new file mode 100644
index 0000000..5d1d3da
--- /dev/null
+++ b/torch_bert.py
@@ -0,0 +1,41 @@
+from bertviz.transformers_neuron_view import BertModel, BertConfig
+from transformers import BertTokenizer
+from typing import NamedTuple
+import numpy
+import torch
+
+from towhee.operator import Operator
+
+
+class TorchBert(Operator):
+    """
+    Text to embedding using BERT.
+    """
+    def __init__(self, max_length: int = 256, framework: str = 'pytorch') -> None:
+        super().__init__()
+        config = BertConfig.from_pretrained("bert-base-cased", output_attentions=True,
+                                            output_hidden_states=True, return_dict=True)
+        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+        config.max_position_embeddings = max_length
+        self.max_length = max_length
+        model = BertModel(config)
+        self.model = model.eval()
+
+    def __call__(self, text: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]):
+        inputs = self.tokenizer(text, truncation=True, padding=True, max_length=self.max_length,
+                                return_tensors='pt')
+        f1 = (torch.index_select(self.model.embeddings.word_embeddings.weight, 0,
+                                 inputs['input_ids'][0])  # word embeddings
+              + torch.index_select(self.model.embeddings.position_embeddings.weight, 0,
+                                   torch.arange(inputs['input_ids'][0].size(0)))  # position embeddings
+              + torch.index_select(self.model.embeddings.token_type_embeddings.weight, 0,
+                                   inputs['token_type_ids'][0]))  # token-type embeddings
+        # layer-normalize the first ([CLS]) token embedding
+        ex1 = f1[0, :]
+        ex1_mean = ex1.mean()
+        ex1_var = (ex1 - ex1_mean).pow(2).mean()
+        norm_embedding = (ex1 - ex1_mean) / torch.sqrt(ex1_var + 1e-12)
+        scaled_embedding = self.model.embeddings.LayerNorm.weight * norm_embedding \
+            + self.model.embeddings.LayerNorm.bias
+        Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)])
+        return Outputs(scaled_embedding.detach().numpy())
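The manual mean/variance normalization in `__call__` reproduces a layer normalization (with BERT's default epsilon of 1e-12) over the 768-dimensional token vector. A self-contained sketch, with random tensors standing in for the summed token embedding and the model's LayerNorm parameters, showing it matches `torch.nn.functional.layer_norm`:

```python
# Sketch only: random tensors stand in for the model's embedding output and
# LayerNorm weight/bias; this is not a test of the operator itself.
import torch
import torch.nn.functional as F

hidden = 768
ex1 = torch.randn(hidden)      # summed word + position + token-type embedding of one token
weight = torch.randn(hidden)   # stands in for model.embeddings.LayerNorm.weight
bias = torch.randn(hidden)     # stands in for model.embeddings.LayerNorm.bias

# manual normalization, as done in TorchBert.__call__
mean = ex1.mean()
var = (ex1 - mean).pow(2).mean()
manual = weight * ((ex1 - mean) / torch.sqrt(var + 1e-12)) + bias

# built-in layer norm with the same epsilon
reference = F.layer_norm(ex1, (hidden,), weight, bias, eps=1e-12)
print(torch.allclose(manual, reference, atol=1e-5))  # True
```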