logo
Browse Source

Add device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
908e07a327
  1. 11
      auto_transformers.py

11
auto_transformers.py

@ -38,11 +38,14 @@ class AutoTransformers(NNOperator):
Which model to use for the embeddings.
"""
def __init__(self, model_name: str = "bert-base-uncased") -> None:
def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None:
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
try:
self.model = AutoModel.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
except Exception as e:
model_list = self.supported_model_names()
@ -59,7 +62,7 @@ class AutoTransformers(NNOperator):
def __call__(self, txt: str) -> numpy.ndarray:
try:
inputs = self.tokenizer(txt, return_tensors="pt")
inputs = self.tokenizer(txt, return_tensors="pt").to(self.device)
except Exception as e:
log.error(f'Invalid input for the tokenizer: {self.model_name}')
raise e
@ -73,7 +76,7 @@ class AutoTransformers(NNOperator):
except Exception as e:
log.error(f'Fail to extract features by model: {self.model_name}')
raise e
vec = features.detach().numpy()
vec = features.cpu().detach().numpy()
return vec
def save_model(self, format: str = 'pytorch', path: str = 'default'):

Loading…
Cancel
Save