diff --git a/README.md b/README.md
index db09535..51887ec 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,119 @@
-# unixcoder
+# Code & Text Embedding with UniXcoder
+
+*author: [Jael Gu](https://github.com/jaelgu)*
+
+
+
+## Description
+
+A code search operator takes a text string of programming language or natural language as an input
+and returns an embedding vector in ndarray which captures the input's core semantic elements.
+This operator is implemented with pre-trained [UniXcoder](https://arxiv.org/pdf/2203.03850.pdf) models
+from [Huggingface Transformers](https://huggingface.co/docs/transformers).
+
+
+
+## Code Example
+
+Use the pre-trained model "microsoft/unixcoder-base"
+to generate text embeddings for given
+text description "return max value" and code "def max(a,b): if a>b: return a else return b".
+
+*Write the pipeline*:
+
+```python
+import towhee
+
+(
+ towhee.dc(['find max value', 'def max(a,b): if a>b: return a else return b'])
+ .code_search.unixcoder(model_name='microsoft/unixcoder-base')
+)
+```
+
+*Write a same pipeline with explicit inputs/outputs name specifications:*
+
+```python
+import towhee
+
+(
+ towhee.dc['text'](['return max value', 'def max(a,b): if a>b: return a else return b'])
+ .code_search.unixcoder['text', 'embedding']()
+ .show()
+)
+```
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***code_search.unixcoder(model_name="microsoft/unixcoder-base")***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string.
+The default model name is "microsoft/unixcoder-base".
+
+***device***: *str*
+
+The device to run model inference.
+The default value is None, which enables GPU if cuda is available.
+
+Supported model names:
+- microsoft/unixcoder-base
+
+
+
+## Interface
+
+The operator takes a piece of text in string as input.
+It loads tokenizer and pre-trained model using model name.
+and then return an embedding in ndarray.
+
+***__call__(txt)***
+
+**Parameters:**
+
+***txt***: *str*
+
+ The text string in programming language or natural language.
+
+
+**Returns**:
+
+*numpy.ndarray*
+
+ The text embedding generated by model, in shape of (dim,).
+
+
+***save_model(format="pytorch", path="default")***
+
+Save model to local with specified format.
+
+**Parameters:**
+
+***format***: *str*
+
+ The format of saved model, defaults to "pytorch".
+
+***format***: *path*
+
+ The path where model is saved to. By default, it will save model to the operator directory.
+
+
+***supported_model_names(format=None)***
+
+Get a list of all supported model names or supported model names for specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+ The model format such as "pytorch", "torchscript".
+The default value is None, which will return all supported model names.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..767040b
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unixcoder import Unixcoder
+
+
+def unixcoder(**kwargs):
+ return Unixcoder(**kwargs)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e514339
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+towhee
+sentence-transformers
+torch>=1.6.0
+transformers>=4.6.0
\ No newline at end of file
diff --git a/unixcoder.py b/unixcoder.py
new file mode 100644
index 0000000..f0c74c7
--- /dev/null
+++ b/unixcoder.py
@@ -0,0 +1,173 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy
+import os
+import torch
+from pathlib import Path
+from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
+
+from towhee.operator import NNOperator
+from towhee import register
+
+import warnings
+import logging
+
+warnings.filterwarnings('ignore')
+logging.getLogger('transformers').setLevel(logging.ERROR)
+log = logging.getLogger()
+
+
+@register(output_schema=['vec'])
+class Unixcoder(NNOperator):
+ """
+ An operator generates an embedding for code or natural language text
+ using a pretrained UniXcoder model gathered by huggingface.
+
+ Args:
+ model_name (`str`):
+ Which model to use for the embeddings.
+ device (`str`):
+ Device to run model inference. Defaults to None, enable GPU when it is available.
+ """
+
+ def __init__(self, model_name: str = 'microsoft/unixcoder-base', device: str = None):
+ super().__init__()
+ self.model_name = model_name
+
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.device = device
+
+ try:
+ self.model = RobertaModel.from_pretrained(model_name).to(self.device)
+ self.model.eval()
+ except Exception as e:
+ model_list = self.supported_model_names()
+ if model_name not in model_list:
+ log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}')
+ else:
+ log.error(f'Fail to load model by name: {self.model_name}')
+ raise e
+ try:
+ self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
+ except Exception as e:
+ log.error(f'Fail to load tokenizer by name: {self.model_name}')
+ raise e
+ try:
+ self.configs = RobertaConfig.from_pretrained(model_name)
+ except Exception as e:
+ log.error(f'Fail to load configs by name: {self.model_name}')
+ raise e
+
+ def __call__(self, txt: str) -> numpy.ndarray:
+ try:
+ tokens = self.tokenizer.tokenize(txt)
+ tokens = [self.tokenizer.cls_token, '', self.tokenizer.sep_token] + tokens + \
+ [self.tokenizer.sep_token]
+ tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+ inputs = torch.tensor(tokens_ids).unsqueeze(0).to(self.device)
+ except Exception as e:
+ log.error(f'Invalid input for the tokenizer: {self.model_name}')
+ raise e
+ try:
+ outs = self.model(inputs)
+ except Exception as e:
+ log.error(f'Invalid input for the model: {self.model_name}')
+ raise e
+ try:
+ features = outs['pooler_output'].squeeze(0)
+ except Exception as e:
+ log.error(f'Fail to extract features by model: {self.model_name}')
+ raise e
+ vec = features.cpu().detach().numpy()
+ return vec
+
+ def save_model(self, format: str = 'pytorch', path: str = 'default'):
+ if path == 'default':
+ path = str(Path(__file__).parent)
+ path = os.path.join(path, 'saved', format)
+ os.makedirs(path, exist_ok=True)
+ name = self.model_name.replace('/', '-')
+ path = os.path.join(path, name)
+ inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device) # return a tensor of token ids
+ if format == 'pytorch':
+ path = path + '.pt'
+ torch.save(self.model, path)
+ elif format == 'torchscript':
+ path = path + '.pt'
+ inputs = list(inputs)
+ try:
+ try:
+ jit_model = torch.jit.script(self.model)
+ except Exception:
+ jit_model = torch.jit.trace(self.model, inputs, strict=False)
+ torch.jit.save(jit_model, path)
+ except Exception as e:
+ log.error(f'Fail to save as torchscript: {e}.')
+ raise RuntimeError(f'Fail to save as torchscript: {e}.')
+ elif format == 'onnx':
+ path = path + '.onnx'
+ try:
+ torch.onnx.export(self.model,
+ tuple(inputs),
+ path,
+ input_names=['input_ids'], # list(inputs.keys())
+ output_names=['last_hidden_state', 'pooler_output'],
+ opset_version=12,
+ dynamic_axes={
+ 'input_ids': {0: 'batch_size', 1: 'input_length'},
+ 'last_hidden_state': {0: 'batch_size'},
+ 'pooler_output': {0: 'batch_size', 1: 'output_dim'},
+ })
+ except Exception:
+ torch.onnx.export(self.model,
+ tuple(inputs.values()),
+ path,
+ input_names=['input_ids'], # list(inputs.keys())
+ output_names=['last_hidden_state'],
+ opset_version=12,
+ dynamic_axes={
+ 'input_ids': {0: 'batch_size', 1: 'input_length'},
+ 'last_hidden_state': {0: 'batch_size'},
+ })
+ # todo: elif format == 'tensorrt':
+ else:
+ log.error(f'Unsupported format "{format}".')
+
+ @staticmethod
+ def supported_model_names(format: str = None):
+ full_list = [
+ 'microsoft/unixcoder-base'
+ ]
+ full_list.sort()
+ if format is None:
+ model_list = full_list
+ elif format == 'pytorch':
+ to_remove = []
+ assert set(to_remove).issubset(set(full_list))
+ model_list = list(set(full_list) - set(to_remove))
+ # todo: elif format == 'torchscript':
+ # to_remove = [
+ # ]
+ # assert set(to_remove).issubset(set(full_list))
+ # model_list = list(set(full_list) - set(to_remove))
+ # todo: elif format == 'onnx':
+ # to_remove = []
+ # assert set(to_remove).issubset(set(full_list))
+ # model_list = list(set(full_list) - set(to_remove))
+ # todo: elif format == 'tensorrt':
+ else:
+ log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
+ return model_list