diff --git a/auto_transformers.py b/auto_transformers.py
index f7dc4d8..050cfb1 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -38,11 +38,14 @@ class AutoTransformers(NNOperator):
             Which model to use for the embeddings.
     """

-    def __init__(self, model_name: str = "bert-base-uncased") -> None:
+    def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None:
         super().__init__()
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = device
         self.model_name = model_name
         try:
-            self.model = AutoModel.from_pretrained(model_name)
+            self.model = AutoModel.from_pretrained(model_name).to(self.device)
             self.model.eval()
         except Exception as e:
             model_list = self.supported_model_names()
@@ -59,7 +62,7 @@ class AutoTransformers(NNOperator):

     def __call__(self, txt: str) -> numpy.ndarray:
         try:
-            inputs = self.tokenizer(txt, return_tensors="pt")
+            inputs = self.tokenizer(txt, return_tensors="pt").to(self.device)
         except Exception as e:
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
@@ -73,7 +76,7 @@ class AutoTransformers(NNOperator):
         except Exception as e:
             log.error(f'Fail to extract features by model: {self.model_name}')
             raise e
-        vec = features.detach().numpy()
+        vec = features.cpu().detach().numpy()
         return vec

     def save_model(self, format: str = 'pytorch', path: str = 'default'):
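
A minimal usage sketch of the change, assuming the operator is importable as AutoTransformers from this module (the model name and input text are illustrative): when device is omitted, the constructor picks CUDA if available and falls back to CPU; an explicit device string overrides that choice, and the returned embedding is always a host-side numpy array because the features are moved to CPU before .numpy().

# Usage sketch (assumption: AutoTransformers is importable from auto_transformers).
from auto_transformers import AutoTransformers

# Device auto-selection: 'cuda' when torch.cuda.is_available(), otherwise 'cpu'.
op = AutoTransformers(model_name="bert-base-uncased")
vec = op("hello world")   # numpy.ndarray on the host, regardless of device
print(vec.shape)

# Explicit override, e.g. to force CPU inference.
op_cpu = AutoTransformers(model_name="bert-base-uncased", device="cpu")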