diff --git a/auto_transformers.py b/auto_transformers.py index 911319c..b4055ce 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -128,12 +128,12 @@ class AutoTransformers(NNOperator): log.error(f'Fail to tokenize inputs: {e}') raise e try: - outs = self.model(**inputs) + outs = self.model(**inputs).to('cpu') except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e outs = self.post_proc(outs, inputs) - features = outs.cpu().detach().numpy() + features = outs.detach().numpy() if isinstance(data, str): features = features.squeeze(0) else: