unixcoder
copied
Jael Gu
2 years ago
4 changed files with 314 additions and 1 deletions
@ -1,2 +1,119 @@ |
|||||
# unixcoder |
|
||||
|
# Code & Text Embedding with UniXcoder |
||||
|
|
||||
|
*author: [Jael Gu](https://github.com/jaelgu)* |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Description |
||||
|
|
||||
|
A code search operator takes a text string of programming language or natural language as an input |
||||
|
and returns an embedding vector in ndarray which captures the input's core semantic elements. |
||||
|
This operator is implemented with pre-trained [UniXcoder](https://arxiv.org/pdf/2203.03850.pdf) models |
||||
|
from [Huggingface Transformers](https://huggingface.co/docs/transformers). |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
Use the pre-trained model "microsoft/unixcoder-base" |
||||
|
to generate text embeddings for given |
||||
|
text description "return max value" and code "def max(a,b): if a>b: return a else return b". |
||||
|
|
||||
|
*Write the pipeline*: |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
( |
||||
|
towhee.dc(['find max value', 'def max(a,b): if a>b: return a else return b']) |
||||
|
.code_search.unixcoder(model_name='microsoft/unixcoder-base') |
||||
|
) |
||||
|
``` |
||||
|
|
||||
|
*Write a same pipeline with explicit inputs/outputs name specifications:* |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
( |
||||
|
towhee.dc['text'](['return max value', 'def max(a,b): if a>b: return a else return b']) |
||||
|
.code_search.unixcoder['text', 'embedding']() |
||||
|
.show() |
||||
|
) |
||||
|
``` |
||||
|
|
||||
|
<img src="./result.png" width="800px"/> |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Factory Constructor |
||||
|
|
||||
|
Create the operator via the following factory method: |
||||
|
|
||||
|
***code_search.unixcoder(model_name="microsoft/unixcoder-base")*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***model_name***: *str* |
||||
|
|
||||
|
The model name in string. |
||||
|
The default model name is "microsoft/unixcoder-base". |
||||
|
|
||||
|
***device***: *str* |
||||
|
|
||||
|
The device to run model inference. |
||||
|
The default value is None, which enables GPU if cuda is available. |
||||
|
|
||||
|
Supported model names: |
||||
|
- microsoft/unixcoder-base |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
The operator takes a piece of text in string as input. |
||||
|
It loads tokenizer and pre-trained model using model name. |
||||
|
and then return an embedding in ndarray. |
||||
|
|
||||
|
***__call__(txt)*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***txt***: *str* |
||||
|
|
||||
|
​ The text string in programming language or natural language. |
||||
|
|
||||
|
|
||||
|
**Returns**: |
||||
|
|
||||
|
*numpy.ndarray* |
||||
|
|
||||
|
​ The text embedding generated by model, in shape of (dim,). |
||||
|
|
||||
|
|
||||
|
***save_model(format="pytorch", path="default")*** |
||||
|
|
||||
|
Save model to local with specified format. |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***format***: *str* |
||||
|
|
||||
|
​ The format of saved model, defaults to "pytorch". |
||||
|
|
||||
|
***format***: *path* |
||||
|
|
||||
|
​ The path where model is saved to. By default, it will save model to the operator directory. |
||||
|
|
||||
|
|
||||
|
***supported_model_names(format=None)*** |
||||
|
|
||||
|
Get a list of all supported model names or supported model names for specified model format. |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***format***: *str* |
||||
|
|
||||
|
​ The model format such as "pytorch", "torchscript". |
||||
|
The default value is None, which will return all supported model names. |
||||
|
|
||||
|
@ -0,0 +1,19 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .unixcoder import Unixcoder |
||||
|
|
||||
|
|
||||
|
def unixcoder(**kwargs): |
||||
|
return Unixcoder(**kwargs) |
@ -0,0 +1,4 @@ |
|||||
|
towhee |
||||
|
sentence-transformers |
||||
|
torch>=1.6.0 |
||||
|
transformers>=4.6.0 |
@ -0,0 +1,173 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import numpy |
||||
|
import os |
||||
|
import torch |
||||
|
from pathlib import Path |
||||
|
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig |
||||
|
|
||||
|
from towhee.operator import NNOperator |
||||
|
from towhee import register |
||||
|
|
||||
|
import warnings |
||||
|
import logging |
||||
|
|
||||
|
warnings.filterwarnings('ignore') |
||||
|
logging.getLogger('transformers').setLevel(logging.ERROR) |
||||
|
log = logging.getLogger() |
||||
|
|
||||
|
|
||||
|
@register(output_schema=['vec']) |
||||
|
class Unixcoder(NNOperator): |
||||
|
""" |
||||
|
An operator generates an embedding for code or natural language text |
||||
|
using a pretrained UniXcoder model gathered by huggingface. |
||||
|
|
||||
|
Args: |
||||
|
model_name (`str`): |
||||
|
Which model to use for the embeddings. |
||||
|
device (`str`): |
||||
|
Device to run model inference. Defaults to None, enable GPU when it is available. |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, model_name: str = 'microsoft/unixcoder-base', device: str = None): |
||||
|
super().__init__() |
||||
|
self.model_name = model_name |
||||
|
|
||||
|
if device is None: |
||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
||||
|
self.device = device |
||||
|
|
||||
|
try: |
||||
|
self.model = RobertaModel.from_pretrained(model_name).to(self.device) |
||||
|
self.model.eval() |
||||
|
except Exception as e: |
||||
|
model_list = self.supported_model_names() |
||||
|
if model_name not in model_list: |
||||
|
log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}') |
||||
|
else: |
||||
|
log.error(f'Fail to load model by name: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
self.tokenizer = RobertaTokenizer.from_pretrained(model_name) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to load tokenizer by name: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
self.configs = RobertaConfig.from_pretrained(model_name) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to load configs by name: {self.model_name}') |
||||
|
raise e |
||||
|
|
||||
|
def __call__(self, txt: str) -> numpy.ndarray: |
||||
|
try: |
||||
|
tokens = self.tokenizer.tokenize(txt) |
||||
|
tokens = [self.tokenizer.cls_token, '<encoder-only>', self.tokenizer.sep_token] + tokens + \ |
||||
|
[self.tokenizer.sep_token] |
||||
|
tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens) |
||||
|
inputs = torch.tensor(tokens_ids).unsqueeze(0).to(self.device) |
||||
|
except Exception as e: |
||||
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
outs = self.model(inputs) |
||||
|
except Exception as e: |
||||
|
log.error(f'Invalid input for the model: {self.model_name}') |
||||
|
raise e |
||||
|
try: |
||||
|
features = outs['pooler_output'].squeeze(0) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to extract features by model: {self.model_name}') |
||||
|
raise e |
||||
|
vec = features.cpu().detach().numpy() |
||||
|
return vec |
||||
|
|
||||
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
||||
|
if path == 'default': |
||||
|
path = str(Path(__file__).parent) |
||||
|
path = os.path.join(path, 'saved', format) |
||||
|
os.makedirs(path, exist_ok=True) |
||||
|
name = self.model_name.replace('/', '-') |
||||
|
path = os.path.join(path, name) |
||||
|
inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device) # return a tensor of token ids |
||||
|
if format == 'pytorch': |
||||
|
path = path + '.pt' |
||||
|
torch.save(self.model, path) |
||||
|
elif format == 'torchscript': |
||||
|
path = path + '.pt' |
||||
|
inputs = list(inputs) |
||||
|
try: |
||||
|
try: |
||||
|
jit_model = torch.jit.script(self.model) |
||||
|
except Exception: |
||||
|
jit_model = torch.jit.trace(self.model, inputs, strict=False) |
||||
|
torch.jit.save(jit_model, path) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to save as torchscript: {e}.') |
||||
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
||||
|
elif format == 'onnx': |
||||
|
path = path + '.onnx' |
||||
|
try: |
||||
|
torch.onnx.export(self.model, |
||||
|
tuple(inputs), |
||||
|
path, |
||||
|
input_names=['input_ids'], # list(inputs.keys()) |
||||
|
output_names=['last_hidden_state', 'pooler_output'], |
||||
|
opset_version=12, |
||||
|
dynamic_axes={ |
||||
|
'input_ids': {0: 'batch_size', 1: 'input_length'}, |
||||
|
'last_hidden_state': {0: 'batch_size'}, |
||||
|
'pooler_output': {0: 'batch_size', 1: 'output_dim'}, |
||||
|
}) |
||||
|
except Exception: |
||||
|
torch.onnx.export(self.model, |
||||
|
tuple(inputs.values()), |
||||
|
path, |
||||
|
input_names=['input_ids'], # list(inputs.keys()) |
||||
|
output_names=['last_hidden_state'], |
||||
|
opset_version=12, |
||||
|
dynamic_axes={ |
||||
|
'input_ids': {0: 'batch_size', 1: 'input_length'}, |
||||
|
'last_hidden_state': {0: 'batch_size'}, |
||||
|
}) |
||||
|
# todo: elif format == 'tensorrt': |
||||
|
else: |
||||
|
log.error(f'Unsupported format "{format}".') |
||||
|
|
||||
|
@staticmethod |
||||
|
def supported_model_names(format: str = None): |
||||
|
full_list = [ |
||||
|
'microsoft/unixcoder-base' |
||||
|
] |
||||
|
full_list.sort() |
||||
|
if format is None: |
||||
|
model_list = full_list |
||||
|
elif format == 'pytorch': |
||||
|
to_remove = [] |
||||
|
assert set(to_remove).issubset(set(full_list)) |
||||
|
model_list = list(set(full_list) - set(to_remove)) |
||||
|
# todo: elif format == 'torchscript': |
||||
|
# to_remove = [ |
||||
|
# ] |
||||
|
# assert set(to_remove).issubset(set(full_list)) |
||||
|
# model_list = list(set(full_list) - set(to_remove)) |
||||
|
# todo: elif format == 'onnx': |
||||
|
# to_remove = [] |
||||
|
# assert set(to_remove).issubset(set(full_list)) |
||||
|
# model_list = list(set(full_list) - set(to_remove)) |
||||
|
# todo: elif format == 'tensorrt': |
||||
|
else: |
||||
|
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') |
||||
|
return model_list |
Loading…
Reference in new issue