logo
Browse Source

Convert sbert names

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
137038d4bb
  1. 17
      auto_transformers.py

17
auto_transformers.py

@ -14,6 +14,7 @@
import numpy import numpy
import os import os
import requests
import torch import torch
import shutil import shutil
from pathlib import Path from pathlib import Path
@ -71,7 +72,7 @@ class AutoTransformers(NNOperator):
self.device = device self.device = device
else: else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model_name = model_name
self.model_name = self.map_model_names(model_name)
if tokenizer: if tokenizer:
self.tokenizer = tokenizer self.tokenizer = tokenizer
else: else:
@ -242,3 +243,17 @@ class AutoTransformers(NNOperator):
else: else:
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list return model_list
@staticmethod
def map_model_names(name):
req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html")
data = req.text
default_sbert = []
for line in data.split('\r\n'):
line = line.replace(' ', '')
if line.startswith('"name":'):
name = line.split(':')[-1].replace('"', '').replace(',', '')
default_sbert.append(name)
if name in default_sbert:
name = 'sentence-transformers/' + name
return name

Loading…
Cancel
Save