logo
Browse Source

update the opus-mt operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
d9a9f920ca
  1. 31
      opus_mt.py

31
opus_mt.py

@ -25,30 +25,31 @@ from towhee import register
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class OpusMT()
class OpusMT(NNOperator):
""" """
Opus-mt machine translation Opus-mt machine translation
""" """
def __init__(self, model_name: str): def __init__(self, model_name: str):
super().__init__() super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
config = self.configs[model_name]
config = self.configs()[model_name]
self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']) self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']) self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model'])
self.model.to(self.device) self.model.to(self.device)
def __call__(self, data): def __call__(self, data):
input_ids = tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device)
outputs = model.generate(input_ids)
decoded = tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True)
input_ids = self.tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device)
outputs = self.model.generate(input_ids)
decoded = self.tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True)
return decoded return decoded
~
def _configs(self):
config = {}
config['opus-mt-en-zh'] = {}
config['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh'
config['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh'
config['opus-mt-zh-en'] = {}
config['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en'
config['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'
def configs(self):
configs = {}
configs['opus-mt-en-zh'] = {}
configs['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh'
configs['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh'
configs['opus-mt-zh-en'] = {}
configs['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en'
configs['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'
return configs

Loading…
Cancel
Save