diff --git a/opus_mt.py b/opus_mt.py
index 0ff5566..df14415 100644
--- a/opus_mt.py
+++ b/opus_mt.py
@@ -25,30 +25,31 @@ from towhee import register
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
-class OpusMT()
+class OpusMT(NNOperator):
     """
     Opus-mt machine translation
     """
     def __init__(self, model_name: str):
         super().__init__()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        config = self.configs[model_name]
+        config = self.configs()[model_name]
         self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
         self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model'])
         self.model.to(self.device)
 
     def __call__(self, data):
-        input_ids = tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device)
-        outputs = model.generate(input_ids)
-        decoded = tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True)
+        input_ids = self.tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device)
+        outputs = self.model.generate(input_ids)
+        decoded = self.tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True)
         return decoded
-~
-    def _configs(self):
-        config = {}
-        config['opus-mt-en-zh'] = {}
-        config['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh'
-        config['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh'
-
-        config['opus-mt-zh-en'] = {}
-        config['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en'
-        config['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'
+
+    def configs(self):
+        configs = {}
+        configs['opus-mt-en-zh'] = {}
+        configs['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh'
+        configs['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh'
+
+        configs['opus-mt-zh-en'] = {}
+        configs['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en'
+        configs['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'
+        return configs