Update tokenizer
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed file with 10 additions and 2 deletions
codebert.py
@@ -16,7 +16,7 @@ import numpy
 import os
 import torch
 from pathlib import Path
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoConfig, AutoTokenizer, AutoModel

 from towhee.operator import NNOperator
 from towhee import register
@@ -65,10 +65,18 @@ class CodeBert(NNOperator):
         except Exception as e:
             log.error(f'Fail to load tokenizer by name: {self.model_name}')
             raise e
+        try:
+            self.configs = AutoConfig.from_pretrained(model_name)
+        except Exception as e:
+            log.error(f'Fail to load configs by name: {self.model_name}')
+            raise e

     def __call__(self, txt: str) -> numpy.ndarray:
         try:
-            inputs = self.tokenizer.encode(txt, return_tensors='pt').to(self.device)
+            tokens = self.tokenizer.tokenize(txt)
+            tokens = [tokenizer.cls_token, '<encoder-only>', tokenizer.sep_token] + tokens + [tokenizer.sep_token]
+            tokens_ids = tokenizer.convert_tokens_to_ids(tokens, return_tensors='pt')
+            inputs = torch.tensor(tokens_ids).to(self.device)
         except Exception as e:
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
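Note on the added tokenization path: as committed, the new lines in __call__ refer to a bare tokenizer rather than self.tokenizer, and convert_tokens_to_ids in the transformers API accepts only the token list, so the extra return_tensors='pt' argument would fail at call time. The sketch below shows the encoding these lines appear to intend, written as a standalone snippet; the microsoft/unixcoder-base checkpoint, the encode helper name, and the added batch dimension are illustrative assumptions (the operator itself uses self.model_name and self.device), not the committed code.

import torch
from transformers import AutoTokenizer

# Illustrative checkpoint only; the operator loads whatever self.model_name points to.
tokenizer = AutoTokenizer.from_pretrained('microsoft/unixcoder-base')

def encode(txt: str, device: str = 'cpu') -> torch.Tensor:
    # Wrap the raw tokens in the special-token layout used by the commit:
    # <cls> <encoder-only> <sep> ... tokens ... <sep>
    tokens = tokenizer.tokenize(txt)
    tokens = [tokenizer.cls_token, '<encoder-only>', tokenizer.sep_token] + tokens + [tokenizer.sep_token]
    # convert_tokens_to_ids returns a plain list of ints; build the tensor explicitly
    # and add a batch dimension (an assumption here) so the model sees shape (1, seq_len).
    tokens_ids = tokenizer.convert_tokens_to_ids(tokens)
    return torch.tensor([tokens_ids]).to(device)

inputs = encode('def add(a, b): return a + b')

The '<encoder-only>' marker is a UniXcoder-style special token; with a plain CodeBERT vocabulary it would map to the unknown-token id, so this prefix presumably assumes the operator is pointed at a UniXcoder checkpoint.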