diff --git a/codebert.py b/codebert.py index 187395a..386bc00 100644 --- a/codebert.py +++ b/codebert.py @@ -74,8 +74,9 @@ class CodeBert(NNOperator): def __call__(self, txt: str) -> numpy.ndarray: try: tokens = self.tokenizer.tokenize(txt) - tokens = [tokenizer.cls_token, '', tokenizer.sep_token] + tokens + [tokenizer.sep_token] - tokens_ids = tokenizer.convert_tokens_to_ids(tokens, return_tensors='pt') + tokens = [self.tokenizer.cls_token, '', self.tokenizer.sep_token] + tokens + \ + [self.tokenizer.sep_token] + tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens, return_tensors='pt') inputs = torch.tensor(tokens_ids).to(self.device) except Exception as e: log.error(f'Invalid input for the tokenizer: {self.model_name}')