diff --git a/codebert.py b/codebert.py index c099494..bf3a22a 100644 --- a/codebert.py +++ b/codebert.py @@ -74,8 +74,7 @@ class CodeBert(NNOperator): def __call__(self, txt: str) -> numpy.ndarray: try: tokens = self.tokenizer.tokenize(txt) - tokens = [self.tokenizer.cls_token, '', self.tokenizer.sep_token] + tokens + \ - [self.tokenizer.sep_token] + tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token] tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens) inputs = torch.tensor(tokens_ids).unsqueeze(0).to(self.device) except Exception as e: