From 137038d4bb155bb7ca2ada406b401c2dfaeae6f1 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Thu, 12 Jan 2023 16:38:53 +0800
Subject: [PATCH] Convert sbert names

Signed-off-by: Jael Gu
---
 auto_transformers.py | 31 ++++++++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index 6104820..e0933ef 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -14,6 +14,7 @@
 
 import numpy
 import os
+import requests
 import torch
 import shutil
 from pathlib import Path
@@ -71,7 +72,7 @@ class AutoTransformers(NNOperator):
             self.device = device
         else:
             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.model_name = model_name
+        self.model_name = self.map_model_names(model_name)
         if tokenizer:
             self.tokenizer = tokenizer
         else:
@@ -242,3 +243,31 @@ class AutoTransformers(NNOperator):
         else:
             log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
         return model_list
+
+    @staticmethod
+    def map_model_names(name):
+        """
+        Prefix `name` with 'sentence-transformers/' if it is listed on the sbert.net models page.
+
+        Model names are parsed from the `"name": ...` lines of the page requested below.
+        If the page cannot be fetched, `name` is returned unchanged.
+        """
+        try:
+            req = requests.get(
+                "https://www.sbert.net/_static/html/models_en_sentence_embeddings.html",
+                timeout=10,
+            )
+            req.raise_for_status()
+        except requests.RequestException as e:
+            log.warning(f'Failed to fetch sbert model names, keep "{name}" unchanged: {e}')
+            return name
+        default_sbert = set()
+        for line in req.text.splitlines():
+            line = ''.join(line.split())
+            if line.startswith('"name":'):
+                # Do not assign to `name` here: that would overwrite the argument.
+                sbert_name = line.split(':')[-1].replace('"', '').replace(',', '')
+                default_sbert.add(sbert_name)
+        if name in default_sbert:
+            name = 'sentence-transformers/' + name
+        return name