|
|
@ -25,30 +25,31 @@ from towhee import register |
|
|
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
|
|
class OpusMT() |
|
|
|
class OpusMT(NNOperator): |
|
|
|
""" |
|
|
|
Opus-mt machine translation |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str): |
|
|
|
super().__init__() |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
config = self.configs[model_name] |
|
|
|
config = self.configs()[model_name] |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']) |
|
|
|
self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']) |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
input_ids = tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device) |
|
|
|
outputs = model.generate(input_ids) |
|
|
|
decoded = tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True) |
|
|
|
input_ids = self.tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device) |
|
|
|
outputs = self.model.generate(input_ids) |
|
|
|
decoded = self.tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True) |
|
|
|
return decoded |
|
|
|
~ |
|
|
|
def _configs(self): |
|
|
|
config = {} |
|
|
|
config['opus-mt-en-zh'] = {} |
|
|
|
config['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh' |
|
|
|
config['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh' |
|
|
|
|
|
|
|
config['opus-mt-zh-en'] = {} |
|
|
|
config['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en' |
|
|
|
config['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en' |
|
|
|
|
|
|
|
def configs(self): |
|
|
|
configs = {} |
|
|
|
configs['opus-mt-en-zh'] = {} |
|
|
|
configs['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh' |
|
|
|
configs['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh' |
|
|
|
|
|
|
|
configs['opus-mt-zh-en'] = {} |
|
|
|
configs['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en' |
|
|
|
configs['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en' |
|
|
|
return configs |
|
|
|