# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path

import torch
from torchvision import transforms

from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class OpusMT(NNOperator):
    """
    Opus-mt machine translation operator.

    Loads a tokenizer and a seq2seq model for the requested translation
    direction and translates text when the instance is called.

    Args:
        model_name (`str`):
            Key into the table returned by `configs()`; currently
            'opus-mt-en-zh' or 'opus-mt-zh-en'.

    Raises:
        KeyError: if `model_name` is not a key of `configs()`.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        supported = self.configs()
        try:
            config = supported[model_name]
        except KeyError:
            # Same exception type as a plain dict lookup, but tell the caller
            # which names are actually available.
            raise KeyError(
                f'Unsupported model_name {model_name!r}; '
                f'expected one of {sorted(supported)}'
            ) from None

        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model'])
        self.model.to(self.device)

    def __call__(self, data) -> str:
        """
        Translate `data` and return the decoded text.

        Args:
            data:
                Text handed to the tokenizer (presumably a single string --
                confirm against callers). If a batch is supplied, only the
                first generated sequence is decoded and returned.

        Returns:
            The first generated sequence decoded to a string with special
            tokens removed.
        """
        encoded = self.tokenizer(data, return_tensors='pt', padding=True)
        input_ids = encoded['input_ids'].to(self.device)
        # Forward the tokenizer's attention mask so that padded positions
        # (padding=True) are explicitly masked during generation instead of
        # relying on the model to infer them.
        attention_mask = encoded['attention_mask'].to(self.device)

        outputs = self.model.generate(input_ids, attention_mask=attention_mask)
        first_sequence = outputs[0].detach().cpu()
        return self.tokenizer.decode(first_sequence, skip_special_tokens=True)

    def configs(self) -> dict:
        """
        Return the table of supported models.

        Returns:
            A new dict mapping each supported model name to a dict with the
            'tokenizer' and 'model' identifiers passed to `from_pretrained`.
        """
        return {
            'opus-mt-en-zh': {
                'tokenizer': 'Helsinki-NLP/opus-mt-en-zh',
                'model': 'Helsinki-NLP/opus-mt-en-zh',
            },
            'opus-mt-zh-en': {
                'tokenizer': 'Helsinki-NLP/opus-mt-zh-en',
                'model': 'Helsinki-NLP/opus-mt-zh-en',
            },
        }