opus-mt
copied
wxywb
2 years ago
3 changed files with 73 additions and 1 deletions
@ -1,2 +1,2 @@ |
|||||
# opus-mt |
|
||||
|
# opus_mt |
||||
|
|
||||
|
@ -0,0 +1,18 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .opus_mt import OpusMT |
||||
|
|
||||
|
def opus_mt(model_name: str):
    """
    Build an Opus-MT translation operator.

    Args:
        model_name (str): name of the Opus-MT model, passed unchanged to
            the ``OpusMT`` constructor.

    Returns:
        The newly constructed ``OpusMT`` instance.
    """
    operator = OpusMT(model_name)
    return operator
@ -0,0 +1,54 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import sys |
||||
|
from pathlib import Path |
||||
|
|
||||
|
import torch |
||||
|
from torchvision import transforms |
||||
|
|
||||
|
from towhee.types.image_utils import to_pil |
||||
|
from towhee.operator.base import NNOperator, OperatorFlag |
||||
|
from towhee.types.arg import arg, to_image_color |
||||
|
from towhee import register |
||||
|
|
||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
||||
|
|
||||
|
class OpusMT:
    """
    Opus-mt machine translation.

    Loads the tokenizer and seq2seq model registered under ``model_name``
    and translates text when the instance is called.

    Args:
        model_name (str): key of the table returned by ``_configs``;
            currently 'opus-mt-en-zh' or 'opus-mt-zh-en'.

    Raises:
        KeyError: if ``model_name`` is not a key of ``_configs()``.
    """

    def __init__(self, model_name: str):
        super().__init__()
        # Resolve the checkpoint names first so an unknown model name fails
        # before any device probing or weight loading happens.
        config = self._configs()[model_name]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model'])
        self.model.to(self.device)

    def __call__(self, data):
        """
        Translate ``data`` and return the decoded text.

        Args:
            data: text handed straight to the tokenizer.
                NOTE(review): presumably a single sentence (str); if a
                list is passed, only the first generated sequence is
                decoded and returned -- confirm against callers.

        Returns:
            The decoded first generated sequence, with special tokens
            stripped.
        """
        input_ids = self.tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device)
        outputs = self.model.generate(input_ids)
        # Only outputs[0] is decoded; move it to CPU before decoding.
        decoded = self.tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True)
        return decoded

    def _configs(self):
        """
        Return the table of supported models.

        Returns:
            dict: maps each supported model name to a dict with the keys
            'tokenizer' and 'model', each holding the checkpoint name
            passed to ``from_pretrained``.
        """
        config = {}
        config['opus-mt-en-zh'] = {}
        config['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh'
        config['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh'

        config['opus-mt-zh-en'] = {}
        config['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en'
        config['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'
        return config
Loading…
Reference in new issue