unixcoder
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
173 lines
7.0 KiB
173 lines
7.0 KiB
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import numpy
|
|
import os
|
|
import torch
|
|
from pathlib import Path
|
|
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
|
|
|
|
from towhee.operator import NNOperator
|
|
from towhee import register
|
|
|
|
import warnings
|
|
import logging
|
|
|
|
# Process-wide side effects performed at import time:
#   * every warning category is ignored (an 'ignore' filter is installed),
#   * the 'transformers' logger only emits ERROR and above,
#   * `log` is bound to the root logger for use by the operator below.
warnings.simplefilter('ignore')
logging.getLogger('transformers').setLevel('ERROR')
log = logging.getLogger()
|
|
|
|
|
|
@register(output_schema=['vec'])
class Unixcoder(NNOperator):
    """
    An operator that generates an embedding for code or natural language text
    using a pretrained UniXcoder model loaded from Hugging Face.

    Args:
        model_name (`str`):
            Which model to use for the embeddings. Defaults to 'microsoft/unixcoder-base'.
        device (`str`):
            Device to run model inference on. Defaults to None, which selects
            'cuda' when it is available and 'cpu' otherwise.
    """

    def __init__(self, model_name: str = 'microsoft/unixcoder-base', device: str = None):
        super().__init__()
        self.model_name = model_name

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        try:
            self.model = RobertaModel.from_pretrained(model_name).to(self.device)
            self.model.eval()
        except Exception:
            # Distinguish "unknown model name" from "known name that failed to load".
            model_list = self.supported_model_names()
            if model_name not in model_list:
                log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}')
            else:
                log.error(f'Fail to load model by name: {self.model_name}')
            raise
        try:
            self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load tokenizer by name: {self.model_name}')
            raise
        try:
            # Read by _encode to bound the input length.
            self.configs = RobertaConfig.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load configs by name: {self.model_name}')
            raise

    def _encode(self, txt: str) -> torch.Tensor:
        """
        Tokenize `txt` into UniXcoder's encoder-only layout
        `<cls> <encoder-only> <sep> tokens... <sep>` and return the token ids
        as a (1, seq_len) tensor on `self.device`.

        Text longer than the model's position-embedding table can hold is
        truncated (with a warning) instead of failing inside the model.
        """
        tokens = self.tokenizer.tokenize(txt)
        # RoBERTa position ids start at pad_token_id + 1, so at most
        # max_position_embeddings - pad_token_id - 1 positions are usable;
        # 4 of those are taken by the special tokens added below.
        max_text_tokens = self.configs.max_position_embeddings - self.configs.pad_token_id - 1 - 4
        if len(tokens) > max_text_tokens:
            log.warning(f'Input of {len(tokens)} tokens truncated to {max_text_tokens} tokens '
                        f'for model: {self.model_name}')
            tokens = tokens[:max_text_tokens]
        tokens = [self.tokenizer.cls_token, '<encoder-only>', self.tokenizer.sep_token] + tokens + \
            [self.tokenizer.sep_token]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return torch.tensor(token_ids).unsqueeze(0).to(self.device)

    def __call__(self, txt: str) -> numpy.ndarray:
        """
        Embed a single piece of code or natural language text.

        Args:
            txt (`str`):
                The code or text to embed.

        Returns:
            `numpy.ndarray`: the model's `pooler_output` for `txt`, with the
            leading batch dimension removed.
        """
        try:
            inputs = self._encode(txt)
        except Exception:
            log.error(f'Invalid input for the tokenizer: {self.model_name}')
            raise
        try:
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                outs = self.model(inputs)
        except Exception:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise
        try:
            features = outs['pooler_output'].squeeze(0)
        except Exception:
            log.error(f'Fail to extract features by model: {self.model_name}')
            raise
        return features.cpu().detach().numpy()

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """
        Save the underlying model to `<path>/saved/<format>/<model name with '/' -> '-'>`.

        Args:
            format (`str`):
                One of 'pytorch', 'torchscript' or 'onnx'. Any other value is
                logged as an error and nothing is written.
            path (`str`):
                Base directory; 'default' means the directory containing this file.

        Raises:
            RuntimeError: if converting to or saving as torchscript fails.
        """
        # todo: 'tensorrt'
        if format not in ('pytorch', 'torchscript', 'onnx'):
            log.error(f'Unsupported format "{format}".')
            return
        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self.model_name.replace('/', '-')
        path = os.path.join(path, name)
        # Example input for tracing/export: token ids shaped (1, seq_len).
        # Pass it as a 1-tuple; iterating the tensor (list()/tuple()) would strip
        # the batch dimension, and RobertaModel expects (batch, seq_len) input_ids.
        inputs = self.tokenizer.encode('test', return_tensors='pt').to(self.device)
        if format == 'pytorch':
            torch.save(self.model, path + '.pt')
        elif format == 'torchscript':
            path = path + '.pt'
            try:
                try:
                    jit_model = torch.jit.script(self.model)
                except Exception:
                    # strict=False lets the traced model return its dict-like output.
                    jit_model = torch.jit.trace(self.model, (inputs,), strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # onnx
            path = path + '.onnx'
            try:
                torch.onnx.export(self.model,
                                  (inputs,),
                                  path,
                                  input_names=['input_ids'],
                                  output_names=['last_hidden_state', 'pooler_output'],
                                  opset_version=12,
                                  dynamic_axes={
                                      'input_ids': {0: 'batch_size', 1: 'input_length'},
                                      'last_hidden_state': {0: 'batch_size'},
                                      'pooler_output': {0: 'batch_size', 1: 'output_dim'},
                                  })
            except Exception:
                # Retry exporting only the last hidden state.
                torch.onnx.export(self.model,
                                  (inputs,),
                                  path,
                                  input_names=['input_ids'],
                                  output_names=['last_hidden_state'],
                                  opset_version=12,
                                  dynamic_axes={
                                      'input_ids': {0: 'batch_size', 1: 'input_length'},
                                      'last_hidden_state': {0: 'batch_size'},
                                  })

    @staticmethod
    def supported_model_names(format: str = None):
        """
        List the model names supported by this operator.

        Args:
            format (`str`):
                None for every supported model, or 'pytorch' for the models that
                can be saved in that format.

        Returns:
            `list` of model names in sorted order.

        Raises:
            ValueError: if `format` is neither None nor 'pytorch'.
        """
        full_list = sorted([
            'microsoft/unixcoder-base',
        ])
        if format is None:
            return full_list
        if format == 'pytorch':
            to_remove = []
            assert set(to_remove).issubset(set(full_list))
            # List comprehension keeps the sorted order (a set difference would not).
            return [model for model in full_list if model not in to_remove]
        # todo: 'torchscript', 'onnx', 'tensorrt'
        log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
        raise ValueError(f'Invalid format "{format}". Currently supported formats: "pytorch".')
|