logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

56 lines
2.1 KiB

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from pathlib import Path
import torch
from torchvision import transforms
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class OpusMT(NNOperator):
    """
    Opus-MT machine translation operator.

    Loads a Helsinki-NLP Opus-MT tokenizer/model pair through Hugging Face
    `transformers` and translates text with `model.generate`. The model is
    placed on CUDA when available, otherwise on CPU.

    Args:
        model_name (`str`):
            A key of `configs()`, i.e. 'opus-mt-en-zh' or 'opus-mt-zh-en'.

    Raises:
        KeyError: if `model_name` is not one of the supported names.
    """
    def __init__(self, model_name: str):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        supported = self.configs()
        try:
            config = supported[model_name]
        except KeyError:
            # Same exception type as before, but tell the caller what is valid.
            raise KeyError(
                f'Unsupported model_name {model_name!r}; '
                f'expected one of {sorted(supported)}'
            ) from None
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model'])
        self.model.to(self.device)

    def __call__(self, data):
        """
        Translate a single string or a batch of strings.

        Args:
            data (`str` or `list` of `str`):
                Source-language text.

        Returns:
            `str` when `data` is a `str`; otherwise a `list` of `str` with one
            translation per input, in input order (empty list for empty input).
        """
        is_single = isinstance(data, str)
        if not is_single and not data:
            # The tokenizer cannot build a tensor batch from zero inputs.
            return []
        encoded = self.tokenizer(data, return_tensors='pt', padding=True)
        input_ids = encoded['input_ids'].to(self.device)
        # With padding=True, shorter batch rows are padded; the attention mask
        # keeps generate() from attending to those pad tokens.
        attention_mask = encoded['attention_mask'].to(self.device)
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
        # Decode every generated sequence, not just the first one.
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded[0] if is_single else decoded

    def configs(self):
        """
        Return the supported model names and the checkpoints they load.

        Returns:
            `dict` mapping each model name to a dict with 'tokenizer' and
            'model' entries, each an identifier passed to `from_pretrained`.
        """
        return {
            name: {
                'tokenizer': f'Helsinki-NLP/{name}',
                'model': f'Helsinki-NLP/{name}',
            }
            for name in ('opus-mt-en-zh', 'opus-mt-zh-en')
        }