codebert
copied
Jael Gu
2 years ago
4 changed files with 308 additions and 1 deletions
@ -1,2 +1,117 @@ |
|||
# codebert |
|||
# Code & Text Embedding with CodeBert |
|||
|
|||
*author: [Jael Gu](https://github.com/jaelgu)* |
|||
|
|||
<br /> |
|||
|
|||
## Description |
|||
|
|||
A code search operator takes a text string of programming language or natural language as an input |
|||
and returns an embedding vector in ndarray which captures the input's core semantic elements. |
|||
This operator is implemented with pre-trained models from [Huggingface Transformers](https://huggingface.co/docs/transformers). |
|||
|
|||
<br /> |
|||
|
|||
## Code Example |
|||
|
|||
Use the pre-trained model "huggingface/CodeBERTa-small-v1" |
|||
to generate text embeddings for given code "" and text description "". |
|||
|
|||
*Write the pipeline*: |
|||
|
|||
```python |
|||
import towhee |
|||
|
|||
( |
|||
towhee.dc([""]) |
|||
.text_embedding.transformers(model_name="distilbert-base-cased") |
|||
) |
|||
``` |
|||
|
|||
*Write a same pipeline with explicit inputs/outputs name specifications:* |
|||
|
|||
```python |
|||
import towhee |
|||
|
|||
( |
|||
towhee.dc['text'](["Hello, world."]) |
|||
.text_embedding.transformers['text', 'vec'](model_name="distilbert-base-cased") |
|||
.show() |
|||
) |
|||
``` |
|||
|
|||
<img src="./result.png" width="800px"/> |
|||
|
|||
<br /> |
|||
|
|||
## Factory Constructor |
|||
|
|||
Create the operator via the following factory method: |
|||
|
|||
***code_search.codebert(model_name="huggingface/CodeBERTa-small-v1")*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***model_name***: *str* |
|||
|
|||
The model name in string. |
|||
The default model name is "huggingface/CodeBERTa-small-v1". |
|||
|
|||
***device***: *str* |
|||
|
|||
The device to run model inference. |
|||
The default value is None, which enables GPU if cuda is available. |
|||
|
|||
Supported model names: |
|||
|
|||
|
|||
<br /> |
|||
|
|||
## Interface |
|||
|
|||
The operator takes a piece of text in string as input. |
|||
It loads tokenizer and pre-trained model using model name. |
|||
and then return an embedding in ndarray. |
|||
|
|||
***__call__(txt)*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***txt***: *str* |
|||
|
|||
The text string in programming language or natural language. |
|||
|
|||
|
|||
**Returns**: |
|||
|
|||
*numpy.ndarray* |
|||
|
|||
The text embedding generated by model, in shape of (dim,). |
|||
|
|||
|
|||
***save_model(format="pytorch", path="default")*** |
|||
|
|||
Save model to local with specified format. |
|||
|
|||
**Parameters:** |
|||
|
|||
***format***: *str* |
|||
|
|||
The format of saved model, defaults to "pytorch". |
|||
|
|||
***format***: *path* |
|||
|
|||
The path where model is saved to. By default, it will save model to the operator directory. |
|||
|
|||
|
|||
***supported_model_names(format=None)*** |
|||
|
|||
Get a list of all supported model names or supported model names for specified model format. |
|||
|
|||
**Parameters:** |
|||
|
|||
***format***: *str* |
|||
|
|||
The model format such as "pytorch", "torchscript". |
|||
The default value is None, which will return all supported model names. |
|||
|
|||
|
@ -0,0 +1,19 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from .codebert import CodeBert |
|||
|
|||
|
|||
def codebert(**kwargs): |
|||
return CodeBert(**kwargs) |
@ -0,0 +1,169 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import numpy |
|||
import os |
|||
import torch |
|||
from pathlib import Path |
|||
from transformers import AutoTokenizer, AutoModel |
|||
|
|||
from towhee.operator import NNOperator |
|||
from towhee import register |
|||
|
|||
import warnings |
|||
import logging |
|||
|
|||
warnings.filterwarnings('ignore') |
|||
logging.getLogger('transformers').setLevel(logging.ERROR) |
|||
log = logging.getLogger() |
|||
|
|||
|
|||
@register(output_schema=['vec']) |
|||
class AutoTransformers(NNOperator): |
|||
""" |
|||
An operator generates an embedding for code or natural language text |
|||
using a pretrained codebert model gathered by huggingface. |
|||
|
|||
Args: |
|||
model_name (`str`): |
|||
Which model to use for the embeddings. |
|||
device (`str`): |
|||
Device to run model inference. Defaults to None, enable GPU when it is available. |
|||
""" |
|||
|
|||
def __init__(self, model_name: str = 'huggingface/CodeBERTa-small-v1', device: str = None): |
|||
super().__init__() |
|||
self.model_name = model_name |
|||
self.modality = modality |
|||
assert modality in ['nlp', 'code'], 'Invalid modality value. Accept only "nlp" or "code".' |
|||
|
|||
if device is None: |
|||
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|||
self.device = device |
|||
|
|||
try: |
|||
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
|||
self.model.eval() |
|||
except Exception as e: |
|||
model_list = self.supported_model_names() |
|||
if model_name not in model_list: |
|||
log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}') |
|||
else: |
|||
log.error(f'Fail to load model by name: {self.model_name}') |
|||
raise e |
|||
try: |
|||
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|||
except Exception as e: |
|||
log.error(f'Fail to load tokenizer by name: {self.model_name}') |
|||
raise e |
|||
|
|||
def __call__(self, txt: str) -> numpy.ndarray: |
|||
try: |
|||
inputs = self.tokenizer.encode(txt, return_tensors='pt').to(self.device) |
|||
except Exception as e: |
|||
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
|||
raise e |
|||
try: |
|||
outs = self.model(inputs) |
|||
except Exception as e: |
|||
log.error(f'Invalid input for the model: {self.model_name}') |
|||
raise e |
|||
try: |
|||
features = outs['pooler_output'].squeeze(0) |
|||
except Exception as e: |
|||
log.error(f'Fail to extract features by model: {self.model_name}') |
|||
raise e |
|||
vec = features.cpu().detach().numpy() |
|||
return vec |
|||
|
|||
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|||
if path == 'default': |
|||
path = str(Path(__file__).parent) |
|||
path = os.path.join(path, 'saved', format) |
|||
os.makedirs(path, exist_ok=True) |
|||
name = self.model_name.replace('/', '-') |
|||
path = os.path.join(path, name) |
|||
inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device) # return a tensor of token ids |
|||
if format == 'pytorch': |
|||
path = path + '.pt' |
|||
torch.save(self.model, path) |
|||
elif format == 'torchscript': |
|||
path = path + '.pt' |
|||
inputs = list(inputs) |
|||
try: |
|||
try: |
|||
jit_model = torch.jit.script(self.model) |
|||
except Exception: |
|||
jit_model = torch.jit.trace(self.model, inputs, strict=False) |
|||
torch.jit.save(jit_model, path) |
|||
except Exception as e: |
|||
log.error(f'Fail to save as torchscript: {e}.') |
|||
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|||
elif format == 'onnx': |
|||
path = path + '.onnx' |
|||
try: |
|||
torch.onnx.export(self.model, |
|||
tuple(inputs), |
|||
path, |
|||
input_names=['input_ids'], # list(inputs.keys()) |
|||
output_names=['last_hidden_state', 'pooler_output'], |
|||
opset_version=12, |
|||
dynamic_axes={ |
|||
'input_ids': {0: 'batch_size', 1: 'input_length'}, |
|||
'last_hidden_state': {0: 'batch_size'}, |
|||
'pooler_output': {0: 'batch_size', 1: 'output_dim'}, |
|||
}) |
|||
except Exception: |
|||
torch.onnx.export(self.model, |
|||
tuple(inputs.values()), |
|||
path, |
|||
input_names=['input_ids'], # list(inputs.keys()) |
|||
output_names=['last_hidden_state'], |
|||
opset_version=12, |
|||
dynamic_axes={ |
|||
'input_ids': {0: 'batch_size', 1: 'input_length'}, |
|||
'last_hidden_state': {0: 'batch_size'}, |
|||
}) |
|||
# todo: elif format == 'tensorrt': |
|||
else: |
|||
log.error(f'Unsupported format "{format}".') |
|||
|
|||
@staticmethod |
|||
def supported_model_names(format: str = None): |
|||
full_list = [ |
|||
'huggingface/CodeBERTa-small-v1', |
|||
'microsoft/codebert-base', |
|||
'microsoft/codebert-base-mlm', |
|||
'mrm8488/codebert-base-finetuned-stackoverflow-ner' |
|||
] |
|||
full_list.sort() |
|||
if format is None: |
|||
model_list = full_list |
|||
elif format == 'pytorch': |
|||
to_remove = [] |
|||
assert set(to_remove).issubset(set(full_list)) |
|||
model_list = list(set(full_list) - set(to_remove)) |
|||
# todo: elif format == 'torchscript': |
|||
# to_remove = [ |
|||
# ] |
|||
# assert set(to_remove).issubset(set(full_list)) |
|||
# model_list = list(set(full_list) - set(to_remove)) |
|||
# todo: elif format == 'onnx': |
|||
# to_remove = [] |
|||
# assert set(to_remove).issubset(set(full_list)) |
|||
# model_list = list(set(full_list) - set(to_remove)) |
|||
# todo: elif format == 'tensorrt': |
|||
else: |
|||
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') |
|||
return model_list |
@ -0,0 +1,4 @@ |
|||
towhee |
|||
sentence-transformers |
|||
torch>=1.6.0 |
|||
transformers>=4.6.0 |
Loading…
Reference in new issue