diff --git a/README.md b/README.md
index 25e6b02..e307d2e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,117 @@
-# codebert
+# Code & Text Embedding with CodeBert
+
+*author: [Jael Gu](https://github.com/jaelgu)*
+
+<br />
+
+## Description
+
+A code search operator takes a text string of programming language or natural language as input
+and returns an embedding vector in ndarray that captures the input's core semantic elements.
+This operator is implemented with pre-trained models from [Huggingface Transformers](https://huggingface.co/docs/transformers).
+
+<br />
+
+## Code Example
+
+Use the pre-trained model "huggingface/CodeBERTa-small-v1"
+to generate an embedding for a given code snippet or text description.
+
+*Write the pipeline*:
+
+```python
+import towhee
+
+(
+    towhee.dc([""])
+    .code_search.codebert(model_name="huggingface/CodeBERTa-small-v1")
+)
+```
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+(
+    towhee.dc['text'](["Hello, world."])
+    .code_search.codebert['text', 'vec'](model_name="huggingface/CodeBERTa-small-v1")
+    .show()
+)
+```
+
+<br />
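+
+To materialize the embedding vectors rather than just building the pipeline, the results can be
+collected into a plain Python list. The snippet below is a minimal sketch: it assumes the
+DataCollection `to_list()` helper, and the input string `"def add(a, b): return a + b"` is only
+an illustration.
+
+```python
+import towhee
+
+vecs = (
+    towhee.dc(["def add(a, b): return a + b"])  # hypothetical example input
+    .code_search.codebert(model_name="huggingface/CodeBERTa-small-v1")
+    .to_list()  # assumes DataCollection exposes to_list()
+)
+print(vecs[0].shape)  # each entry is a 1-d numpy.ndarray of shape (dim,)
+```
+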
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***code_search.codebert(model_name="huggingface/CodeBERTa-small-v1")***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string.
+The default model name is "huggingface/CodeBERTa-small-v1".
+
+***device***: *str*
+
+The device on which to run model inference.
+The default value is None, which enables GPU if CUDA is available.
+
+Supported model names:
+- huggingface/CodeBERTa-small-v1
+- microsoft/codebert-base
+- microsoft/codebert-base-mlm
+- mrm8488/codebert-base-finetuned-stackoverflow-ner
+
+<br />
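+
+The operator can also be instantiated on its own and called on a single string. This is a
+minimal sketch and assumes the `towhee.ops` entry point resolves this operator and returns a
+callable; the code string passed to it is only an illustration:
+
+```python
+from towhee import ops
+
+# assumption: ops.code_search.codebert resolves this operator from the Towhee hub
+op = ops.code_search.codebert(model_name="huggingface/CodeBERTa-small-v1", device=None)
+vec = op("def bubble_sort(arr): ...")  # hypothetical input; returns a numpy.ndarray of shape (dim,)
+```
+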
+
+## Interface
+
+The operator takes a piece of text in string as input.
+It loads the tokenizer and pre-trained model by model name,
+and then returns an embedding in ndarray.
+
+***__call__(txt)***
+
+**Parameters:**
+
+***txt***: *str*
+
+The text string in programming language or natural language.
+
+
+**Returns**:
+
+*numpy.ndarray*
+
+The text embedding generated by the model, in shape of (dim,).
+
+
+***save_model(format="pytorch", path="default")***
+
+Save the model to a local path in the specified format.
+
+**Parameters:**
+
+***format***: *str*
+
+The format of the saved model, defaults to "pytorch".
+
+***path***: *str*
+
+The path where the model is saved to. By default, it will save the model to the operator directory.
+
+
+***supported_model_names(format=None)***
+
+Get a list of all supported model names, or the model names supported for a specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+The model format such as "pytorch" or "torchscript".
+The default value is None, which will return all supported model names.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..037ec95
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .codebert import CodeBert
+
+
+def codebert(**kwargs):
+    return CodeBert(**kwargs)
diff --git a/codebert.py b/codebert.py
new file mode 100644
index 0000000..2761f97
--- /dev/null
+++ b/codebert.py
@@ -0,0 +1,169 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy
+import os
+import torch
+from pathlib import Path
+from transformers import AutoTokenizer, AutoModel
+
+from towhee.operator import NNOperator
+from towhee import register
+
+import warnings
+import logging
+
+warnings.filterwarnings('ignore')
+logging.getLogger('transformers').setLevel(logging.ERROR)
+log = logging.getLogger()
+
+
+@register(output_schema=['vec'])
+class CodeBert(NNOperator):
+    """
+    An operator that generates an embedding for code or natural language text
+    using a pretrained CodeBERT model from Hugging Face Transformers.
+
+    Args:
+        model_name (`str`):
+            Which model to use for the embeddings.
+        device (`str`):
+            Device to run model inference. Defaults to None, which enables GPU when it is available.
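+
+    Example (an illustrative sketch, not a doctest; it assumes this file is importable as
+    `codebert` and the code snippet below is hypothetical):
+
+        from codebert import CodeBert  # assumption: module path for this operator file
+
+        op = CodeBert(model_name='huggingface/CodeBERTa-small-v1')
+        vec = op('def add(a, b): return a + b')           # 1-d numpy.ndarray of shape (dim,)
+        op.save_model(format='pytorch', path='default')   # saves under the operator directory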
+ """ + + def __init__(self, model_name: str = 'huggingface/CodeBERTa-small-v1', device: str = None): + super().__init__() + self.model_name = model_name + self.modality = modality + assert modality in ['nlp', 'code'], 'Invalid modality value. Accept only "nlp" or "code".' + + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + + try: + self.model = AutoModel.from_pretrained(model_name).to(self.device) + self.model.eval() + except Exception as e: + model_list = self.supported_model_names() + if model_name not in model_list: + log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}') + else: + log.error(f'Fail to load model by name: {self.model_name}') + raise e + try: + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + except Exception as e: + log.error(f'Fail to load tokenizer by name: {self.model_name}') + raise e + + def __call__(self, txt: str) -> numpy.ndarray: + try: + inputs = self.tokenizer.encode(txt, return_tensors='pt').to(self.device) + except Exception as e: + log.error(f'Invalid input for the tokenizer: {self.model_name}') + raise e + try: + outs = self.model(inputs) + except Exception as e: + log.error(f'Invalid input for the model: {self.model_name}') + raise e + try: + features = outs['pooler_output'].squeeze(0) + except Exception as e: + log.error(f'Fail to extract features by model: {self.model_name}') + raise e + vec = features.cpu().detach().numpy() + return vec + + def save_model(self, format: str = 'pytorch', path: str = 'default'): + if path == 'default': + path = str(Path(__file__).parent) + path = os.path.join(path, 'saved', format) + os.makedirs(path, exist_ok=True) + name = self.model_name.replace('/', '-') + path = os.path.join(path, name) + inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device) # return a tensor of token ids + if format == 'pytorch': + path = path + '.pt' + torch.save(self.model, path) + elif format == 'torchscript': + path = path + '.pt' + inputs = list(inputs) + try: + try: + jit_model = torch.jit.script(self.model) + except Exception: + jit_model = torch.jit.trace(self.model, inputs, strict=False) + torch.jit.save(jit_model, path) + except Exception as e: + log.error(f'Fail to save as torchscript: {e}.') + raise RuntimeError(f'Fail to save as torchscript: {e}.') + elif format == 'onnx': + path = path + '.onnx' + try: + torch.onnx.export(self.model, + tuple(inputs), + path, + input_names=['input_ids'], # list(inputs.keys()) + output_names=['last_hidden_state', 'pooler_output'], + opset_version=12, + dynamic_axes={ + 'input_ids': {0: 'batch_size', 1: 'input_length'}, + 'last_hidden_state': {0: 'batch_size'}, + 'pooler_output': {0: 'batch_size', 1: 'output_dim'}, + }) + except Exception: + torch.onnx.export(self.model, + tuple(inputs.values()), + path, + input_names=['input_ids'], # list(inputs.keys()) + output_names=['last_hidden_state'], + opset_version=12, + dynamic_axes={ + 'input_ids': {0: 'batch_size', 1: 'input_length'}, + 'last_hidden_state': {0: 'batch_size'}, + }) + # todo: elif format == 'tensorrt': + else: + log.error(f'Unsupported format "{format}".') + + @staticmethod + def supported_model_names(format: str = None): + full_list = [ + 'huggingface/CodeBERTa-small-v1', + 'microsoft/codebert-base', + 'microsoft/codebert-base-mlm', + 'mrm8488/codebert-base-finetuned-stackoverflow-ner' + ] + full_list.sort() + if format is None: + model_list = full_list + elif format == 'pytorch': + to_remove = [] + assert 
+            assert set(to_remove).issubset(set(full_list))
+            model_list = list(set(full_list) - set(to_remove))
+        # todo: elif format == 'torchscript':
+        #     to_remove = [
+        #     ]
+        #     assert set(to_remove).issubset(set(full_list))
+        #     model_list = list(set(full_list) - set(to_remove))
+        # todo: elif format == 'onnx':
+        #     to_remove = []
+        #     assert set(to_remove).issubset(set(full_list))
+        #     model_list = list(set(full_list) - set(to_remove))
+        # todo: elif format == 'tensorrt':
+        else:
+            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
+            raise ValueError(f'Invalid format "{format}". Currently supported formats: "pytorch".')
+        return model_list
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e514339
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+towhee
+sentence-transformers
+torch>=1.6.0
+transformers>=4.6.0
\ No newline at end of file