codebert
copied
Jael Gu
2 years ago
4 changed files with 308 additions and 1 deletions
@ -1,2 +1,117 @@ |
|||||
# codebert |
|
||||
|
# Code & Text Embedding with CodeBert |
||||
|
|
||||
|
*author: [Jael Gu](https://github.com/jaelgu)* |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Description |
||||
|
|
||||
|
A code search operator takes a text string of programming language or natural language as an input |
||||
|
and returns an embedding vector in ndarray which captures the input's core semantic elements. |
||||
|
This operator is implemented with pre-trained models from [Huggingface Transformers](https://huggingface.co/docs/transformers). |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
Use the pre-trained model "huggingface/CodeBERTa-small-v1" |
||||
|
to generate text embeddings for given code "" and text description "". |
||||
|
|
||||
|
*Write the pipeline*: |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
( |
||||
|
towhee.dc([""]) |
||||
|
.text_embedding.transformers(model_name="distilbert-base-cased") |
||||
|
) |
||||
|
``` |
||||
|
|
||||
|
*Write a same pipeline with explicit inputs/outputs name specifications:* |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
( |
||||
|
towhee.dc['text'](["Hello, world."]) |
||||
|
.text_embedding.transformers['text', 'vec'](model_name="distilbert-base-cased") |
||||
|
.show() |
||||
|
) |
||||
|
``` |
||||
|
|
||||
|
<img src="./result.png" width="800px"/> |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Factory Constructor |
||||
|
|
||||
|
Create the operator via the following factory method: |
||||
|
|
||||
|
***code_search.codebert(model_name="huggingface/CodeBERTa-small-v1")*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***model_name***: *str* |
||||
|
|
||||
|
The model name in string. |
||||
|
The default model name is "huggingface/CodeBERTa-small-v1". |
||||
|
|
||||
|
***device***: *str* |
||||
|
|
||||
|
The device to run model inference. |
||||
|
The default value is None, which enables GPU if cuda is available. |
||||
|
|
||||
|
Supported model names: |
||||
|
|
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
The operator takes a piece of text in string as input. |
||||
|
It loads tokenizer and pre-trained model using model name. |
||||
|
and then return an embedding in ndarray. |
||||
|
|
||||
|
***__call__(txt)*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***txt***: *str* |
||||
|
|
||||
|
​ The text string in programming language or natural language. |
||||
|
|
||||
|
|
||||
|
**Returns**: |
||||
|
|
||||
|
*numpy.ndarray* |
||||
|
|
||||
|
​ The text embedding generated by model, in shape of (dim,). |
||||
|
|
||||
|
|
||||
|
***save_model(format="pytorch", path="default")*** |
||||
|
|
||||
|
Save model to local with specified format. |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***format***: *str* |
||||
|
|
||||
|
​ The format of saved model, defaults to "pytorch". |
||||
|
|
||||
|
***format***: *path* |
||||
|
|
||||
|
​ The path where model is saved to. By default, it will save model to the operator directory. |
||||
|
|
||||
|
|
||||
|
***supported_model_names(format=None)*** |
||||
|
|
||||
|
Get a list of all supported model names or supported model names for specified model format. |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***format***: *str* |
||||
|
|
||||
|
​ The model format such as "pytorch", "torchscript". |
||||
|
The default value is None, which will return all supported model names. |
||||
|
|
||||
|
@ -0,0 +1,19 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .codebert import CodeBert |
||||
|
|
||||
|
|
||||
|
def codebert(**kwargs): |
||||
|
return CodeBert(**kwargs) |
@ -0,0 +1,169 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import numpy |
||||
|
import os |
||||
|
import torch |
||||
|
from pathlib import Path |
||||
|
from transformers import AutoTokenizer, AutoModel |
||||
|
|
||||
|
from towhee.operator import NNOperator |
||||
|
from towhee import register |
||||
|
|
||||
|
import warnings |
||||
|
import logging |
||||
|
|
||||
|
warnings.filterwarnings('ignore') |
||||
|
logging.getLogger('transformers').setLevel(logging.ERROR) |
||||
|
log = logging.getLogger() |
||||
|
|
||||
|
|
||||
|
@register(output_schema=['vec']) |
||||
|
class AutoTransformers(NNOperator): |
||||
|
""" |
||||
|
An operator generates an embedding for code or natural language text |
||||
|
using a pretrained codebert model gathered by huggingface. |
||||
|
|
||||
|
Args: |
||||
|
model_name (`str`): |
||||
|
Which model to use for the embeddings. |
||||
|
device (`str`): |
||||
|
Device to run model inference. Defaults to None, enable GPU when it is available. |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, model_name: str = 'huggingface/CodeBERTa-small-v1', device: str = None): |
||||
|
super().__init__() |
||||
|
self.model_name = model_name |
||||
|
self.modality = modality |
||||
|
assert modality in ['nlp', 'code'], 'Invalid modality value. Accept only "nlp" or "code".' |
||||
|
|
||||
|
if device is None: |
||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
||||
|
self.device = device |
||||
|
|
||||
|
try: |
||||
|
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
||||
|
self.model.eval() |
||||
|
except Exception as e: |
||||
|
model_list = self.supported_model_names() |
||||
|
if model_name not in model_list: |
||||
|
log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}') |
||||
|
else: |
||||
|
log.error(f'Fail to load model by name: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to load tokenizer by name: {self.model_name}') |
||||
|
raise e |
||||
|
|
||||
|
def __call__(self, txt: str) -> numpy.ndarray: |
||||
|
try: |
||||
|
inputs = self.tokenizer.encode(txt, return_tensors='pt').to(self.device) |
||||
|
except Exception as e: |
||||
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
outs = self.model(inputs) |
||||
|
except Exception as e: |
||||
|
log.error(f'Invalid input for the model: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
features = outs['pooler_output'].squeeze(0) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to extract features by model: {self.model_name}') |
||||
|
raise e |
||||
|
vec = features.cpu().detach().numpy() |
||||
|
return vec |
||||
|
|
||||
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
||||
|
if path == 'default': |
||||
|
path = str(Path(__file__).parent) |
||||
|
path = os.path.join(path, 'saved', format) |
||||
|
os.makedirs(path, exist_ok=True) |
||||
|
name = self.model_name.replace('/', '-') |
||||
|
path = os.path.join(path, name) |
||||
|
inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device) # return a tensor of token ids |
||||
|
if format == 'pytorch': |
||||
|
path = path + '.pt' |
||||
|
torch.save(self.model, path) |
||||
|
elif format == 'torchscript': |
||||
|
path = path + '.pt' |
||||
|
inputs = list(inputs) |
||||
|
try: |
||||
|
try: |
||||
|
jit_model = torch.jit.script(self.model) |
||||
|
except Exception: |
||||
|
jit_model = torch.jit.trace(self.model, inputs, strict=False) |
||||
|
torch.jit.save(jit_model, path) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to save as torchscript: {e}.') |
||||
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
||||
|
elif format == 'onnx': |
||||
|
path = path + '.onnx' |
||||
|
try: |
||||
|
torch.onnx.export(self.model, |
||||
|
tuple(inputs), |
||||
|
path, |
||||
|
input_names=['input_ids'], # list(inputs.keys()) |
||||
|
output_names=['last_hidden_state', 'pooler_output'], |
||||
|
opset_version=12, |
||||
|
dynamic_axes={ |
||||
|
'input_ids': {0: 'batch_size', 1: 'input_length'}, |
||||
|
'last_hidden_state': {0: 'batch_size'}, |
||||
|
'pooler_output': {0: 'batch_size', 1: 'output_dim'}, |
||||
|
}) |
||||
|
except Exception: |
||||
|
torch.onnx.export(self.model, |
||||
|
tuple(inputs.values()), |
||||
|
path, |
||||
|
input_names=['input_ids'], # list(inputs.keys()) |
||||
|
output_names=['last_hidden_state'], |
||||
|
opset_version=12, |
||||
|
dynamic_axes={ |
||||
|
'input_ids': {0: 'batch_size', 1: 'input_length'}, |
||||
|
'last_hidden_state': {0: 'batch_size'}, |
||||
|
}) |
||||
|
# todo: elif format == 'tensorrt': |
||||
|
else: |
||||
|
log.error(f'Unsupported format "{format}".') |
||||
|
|
||||
|
@staticmethod |
||||
|
def supported_model_names(format: str = None): |
||||
|
full_list = [ |
||||
|
'huggingface/CodeBERTa-small-v1', |
||||
|
'microsoft/codebert-base', |
||||
|
'microsoft/codebert-base-mlm', |
||||
|
'mrm8488/codebert-base-finetuned-stackoverflow-ner' |
||||
|
] |
||||
|
full_list.sort() |
||||
|
if format is None: |
||||
|
model_list = full_list |
||||
|
elif format == 'pytorch': |
||||
|
to_remove = [] |
||||
|
assert set(to_remove).issubset(set(full_list)) |
||||
|
model_list = list(set(full_list) - set(to_remove)) |
||||
|
# todo: elif format == 'torchscript': |
||||
|
# to_remove = [ |
||||
|
# ] |
||||
|
# assert set(to_remove).issubset(set(full_list)) |
||||
|
# model_list = list(set(full_list) - set(to_remove)) |
||||
|
# todo: elif format == 'onnx': |
||||
|
# to_remove = [] |
||||
|
# assert set(to_remove).issubset(set(full_list)) |
||||
|
# model_list = list(set(full_list) - set(to_remove)) |
||||
|
# todo: elif format == 'tensorrt': |
||||
|
else: |
||||
|
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') |
||||
|
return model_list |
@ -0,0 +1,4 @@ |
|||||
|
towhee |
||||
|
sentence-transformers |
||||
|
torch>=1.6.0 |
||||
|
transformers>=4.6.0 |
Loading…
Reference in new issue