Browse Source
Convert sbert names
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
16 additions and
1 deletions
-
auto_transformers.py
|
|
@ -14,6 +14,7 @@ |
|
|
|
|
|
|
|
import numpy |
|
|
|
import os |
|
|
|
import requests |
|
|
|
import torch |
|
|
|
import shutil |
|
|
|
from pathlib import Path |
|
|
@ -71,7 +72,7 @@ class AutoTransformers(NNOperator): |
|
|
|
self.device = device |
|
|
|
else: |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.model_name = model_name |
|
|
|
self.model_name = self.map_model_names(model_name) |
|
|
|
if tokenizer: |
|
|
|
self.tokenizer = tokenizer |
|
|
|
else: |
|
|
@ -242,3 +243,17 @@ class AutoTransformers(NNOperator): |
|
|
|
else: |
|
|
|
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') |
|
|
|
return model_list |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def map_model_names(name): |
|
|
|
req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html") |
|
|
|
data = req.text |
|
|
|
default_sbert = [] |
|
|
|
for line in data.split('\r\n'): |
|
|
|
line = line.replace(' ', '') |
|
|
|
if line.startswith('"name":'): |
|
|
|
name = line.split(':')[-1].replace('"', '').replace(',', '') |
|
|
|
default_sbert.append(name) |
|
|
|
if name in default_sbert: |
|
|
|
name = 'sentence-transformers/' + name |
|
|
|
return name |
|
|
|