logo
Browse Source

Add files

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
a3bf057a59
  1. 117
      README.md
  2. 19
      __init__.py
  3. 169
      codebert.py
  4. 4
      requirements.txt

117
README.md

@ -1,2 +1,117 @@
# codebert
# Code & Text Embedding with CodeBert
*author: [Jael Gu](https://github.com/jaelgu)*
<br />
## Description
A code search operator takes a text string of programming language or natural language as an input
and returns an embedding vector in ndarray which captures the input's core semantic elements.
This operator is implemented with pre-trained models from [Huggingface Transformers](https://huggingface.co/docs/transformers).
<br />
## Code Example
Use the pre-trained model "huggingface/CodeBERTa-small-v1"
to generate text embeddings for given code "" and text description "".
*Write the pipeline*:
```python
import towhee
(
towhee.dc([""])
.text_embedding.transformers(model_name="distilbert-base-cased")
)
```
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
import towhee
(
towhee.dc['text'](["Hello, world."])
.text_embedding.transformers['text', 'vec'](model_name="distilbert-base-cased")
.show()
)
```
<img src="./result.png" width="800px"/>
<br />
## Factory Constructor
Create the operator via the following factory method:
***code_search.codebert(model_name="huggingface/CodeBERTa-small-v1")***
**Parameters:**
***model_name***: *str*
The model name in string.
The default model name is "huggingface/CodeBERTa-small-v1".
***device***: *str*
The device to run model inference.
The default value is None, which enables GPU if cuda is available.
Supported model names:
<br />
## Interface
The operator takes a piece of text in string as input.
It loads tokenizer and pre-trained model using model name.
and then return an embedding in ndarray.
***__call__(txt)***
**Parameters:**
***txt***: *str*
​ The text string in programming language or natural language.
**Returns**:
*numpy.ndarray*
​ The text embedding generated by model, in shape of (dim,).
***save_model(format="pytorch", path="default")***
Save model to local with specified format.
**Parameters:**
***format***: *str*
​ The format of saved model, defaults to "pytorch".
***format***: *path*
​ The path where model is saved to. By default, it will save model to the operator directory.
***supported_model_names(format=None)***
Get a list of all supported model names or supported model names for specified model format.
**Parameters:**
***format***: *str*
​ The model format such as "pytorch", "torchscript".
The default value is None, which will return all supported model names.

19
__init__.py

@ -0,0 +1,19 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .codebert import CodeBert
def codebert(**kwargs):
return CodeBert(**kwargs)

169
codebert.py

@ -0,0 +1,169 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import os
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from towhee.operator import NNOperator
from towhee import register
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger('transformers').setLevel(logging.ERROR)
log = logging.getLogger()
@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
"""
An operator generates an embedding for code or natural language text
using a pretrained codebert model gathered by huggingface.
Args:
model_name (`str`):
Which model to use for the embeddings.
device (`str`):
Device to run model inference. Defaults to None, enable GPU when it is available.
"""
def __init__(self, model_name: str = 'huggingface/CodeBERTa-small-v1', device: str = None):
super().__init__()
self.model_name = model_name
self.modality = modality
assert modality in ['nlp', 'code'], 'Invalid modality value. Accept only "nlp" or "code".'
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
try:
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
except Exception as e:
model_list = self.supported_model_names()
if model_name not in model_list:
log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}')
else:
log.error(f'Fail to load model by name: {self.model_name}')
raise e
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
log.error(f'Fail to load tokenizer by name: {self.model_name}')
raise e
def __call__(self, txt: str) -> numpy.ndarray:
try:
inputs = self.tokenizer.encode(txt, return_tensors='pt').to(self.device)
except Exception as e:
log.error(f'Invalid input for the tokenizer: {self.model_name}')
raise e
try:
outs = self.model(inputs)
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
try:
features = outs['pooler_output'].squeeze(0)
except Exception as e:
log.error(f'Fail to extract features by model: {self.model_name}')
raise e
vec = features.cpu().detach().numpy()
return vec
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
path = os.path.join(path, 'saved', format)
os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device) # return a tensor of token ids
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
elif format == 'torchscript':
path = path + '.pt'
inputs = list(inputs)
try:
try:
jit_model = torch.jit.script(self.model)
except Exception:
jit_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
try:
torch.onnx.export(self.model,
tuple(inputs),
path,
input_names=['input_ids'], # list(inputs.keys())
output_names=['last_hidden_state', 'pooler_output'],
opset_version=12,
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'input_length'},
'last_hidden_state': {0: 'batch_size'},
'pooler_output': {0: 'batch_size', 1: 'output_dim'},
})
except Exception:
torch.onnx.export(self.model,
tuple(inputs.values()),
path,
input_names=['input_ids'], # list(inputs.keys())
output_names=['last_hidden_state'],
opset_version=12,
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'input_length'},
'last_hidden_state': {0: 'batch_size'},
})
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')
@staticmethod
def supported_model_names(format: str = None):
full_list = [
'huggingface/CodeBERTa-small-v1',
'microsoft/codebert-base',
'microsoft/codebert-base-mlm',
'mrm8488/codebert-base-finetuned-stackoverflow-ner'
]
full_list.sort()
if format is None:
model_list = full_list
elif format == 'pytorch':
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'torchscript':
# to_remove = [
# ]
# assert set(to_remove).issubset(set(full_list))
# model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'onnx':
# to_remove = []
# assert set(to_remove).issubset(set(full_list))
# model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'tensorrt':
else:
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
return model_list

4
requirements.txt

@ -0,0 +1,4 @@
towhee
sentence-transformers
torch>=1.6.0
transformers>=4.6.0
Loading…
Cancel
Save