|
|
@ -22,6 +22,7 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoMod |
|
|
|
|
|
|
|
from towhee.operator import NNOperator |
|
|
|
from towhee import register |
|
|
|
from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
import warnings |
|
|
|
import logging |
|
|
@ -34,6 +35,26 @@ log = logging.getLogger('run_op') |
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
|
|
|
|
|
|
@accelerate
class Model:
    """Accelerated wrapper around a HuggingFace ``AutoModel``.

    Loads a pretrained transformer by name, optionally overlays a local
    checkpoint (state dict), and exposes the underlying model's forward
    pass via ``__call__`` so the ``@accelerate`` decorator can dispatch it.

    Args:
        model_name: HuggingFace model identifier passed to
            ``AutoModel.from_pretrained``.
        device: torch device string/object the model is moved to.
        checkpoint_path: optional path to a ``torch.save``-ed state dict;
            falsy values skip the checkpoint load entirely.

    Raises:
        Exception: re-raised from ``AutoModel.from_pretrained`` when the
            model name cannot be resolved/downloaded.
    """

    def __init__(self, model_name, device, checkpoint_path):
        try:
            self.model = AutoModel.from_pretrained(model_name).to(device)
        except Exception as e:
            # BUG FIX: the original logged f"...{self.model_name}", but this
            # class never assigns self.model_name, so the handler itself died
            # with AttributeError and hid the real load failure. Log the
            # constructor parameter instead.
            log.error(f"Fail to load model by name: {model_name}")
            raise e
        if checkpoint_path:
            try:
                state_dict = torch.load(checkpoint_path, map_location=device)
                self.model.load_state_dict(state_dict)
            except Exception as e:
                # NOTE(review): deliberately non-fatal — a bad checkpoint
                # logs and falls back to the pretrained weights. This matches
                # the identical handling elsewhere in this file; confirm that
                # silently ignoring a requested checkpoint is intended.
                log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
        # Inference-only wrapper: disable dropout/batch-norm training behavior.
        self.model.eval()

    def __call__(self, *args, **kwargs):
        """Delegate directly to the wrapped model's forward pass."""
        return self.model(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
|
class AutoTransformers(NNOperator): |
|
|
|
""" |
|
|
@ -57,26 +78,14 @@ class AutoTransformers(NNOperator): |
|
|
|
if device is None: |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.device = device |
|
|
|
|
|
|
|
model_list = self.supported_model_names() |
|
|
|
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
if self.model_name: |
|
|
|
model_list = self.supported_model_names() |
|
|
|
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
|
|
|
|
try: |
|
|
|
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
|
|
|
self.configs = self.model.config |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
if checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=self.device) |
|
|
|
self.model.load_state_dict(state_dict) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) |
|
|
|
self.configs = self.model.model.config |
|
|
|
if tokenizer is None: |
|
|
|
try: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
@ -135,13 +144,13 @@ class AutoTransformers(NNOperator): |
|
|
|
from transformers.onnx.features import FeaturesManager |
|
|
|
from transformers.onnx import export |
|
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
|
self.model, feature='default') |
|
|
|
onnx_config = model_onnx_config(self.model.config) |
|
|
|
self.model.model, feature='default') |
|
|
|
onnx_config = model_onnx_config(self.configs) |
|
|
|
if os.path.isdir(path): |
|
|
|
shutil.rmtree(path) |
|
|
|
onnx_inputs, onnx_outputs = export( |
|
|
|
self.tokenizer, |
|
|
|
self.model, |
|
|
|
self.model.model, |
|
|
|
config=onnx_config, |
|
|
|
opset=13, |
|
|
|
output=Path(path+'.onnx') |
|
|
|