diff --git a/codebert.py b/codebert.py index 274e1a2..c099494 100644 --- a/codebert.py +++ b/codebert.py @@ -77,7 +77,7 @@ class CodeBert(NNOperator): tokens = [self.tokenizer.cls_token, '', self.tokenizer.sep_token] + tokens + \ [self.tokenizer.sep_token] tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens) - inputs = torch.tensor(tokens_ids).to(self.device) + inputs = torch.tensor(tokens_ids).unsqueeze(0).to(self.device) except Exception as e: log.error(f'Invalid input for the tokenizer: {self.model_name}') raise e