diff --git a/codebert.py b/codebert.py index a80ce6c..187395a 100644 --- a/codebert.py +++ b/codebert.py @@ -16,7 +16,7 @@ import numpy import os import torch from pathlib import Path -from transformers import AutoTokenizer, AutoModel +from transformers import AutoConfig, AutoTokenizer, AutoModel from towhee.operator import NNOperator from towhee import register @@ -65,10 +65,18 @@ class CodeBert(NNOperator): except Exception as e: log.error(f'Fail to load tokenizer by name: {self.model_name}') raise e + try: + self.configs = AutoConfig.from_pretrained(model_name) + except Exception as e: + log.error(f'Fail to load configs by name: {self.model_name}') + raise e def __call__(self, txt: str) -> numpy.ndarray: try: - inputs = self.tokenizer.encode(txt, return_tensors='pt').to(self.device) + tokens = self.tokenizer.tokenize(txt) + tokens = [tokenizer.cls_token, '', tokenizer.sep_token] + tokens + [tokenizer.sep_token] + tokens_ids = tokenizer.convert_tokens_to_ids(tokens, return_tensors='pt') + inputs = torch.tensor(tokens_ids).to(self.device) except Exception as e: log.error(f'Invalid input for the tokenizer: {self.model_name}') raise e