opus-mt
copied
wxywb
2 years ago
3 changed files with 73 additions and 1 deletion
@@ -1,2 +1,2 @@
-# opus-mt
+# opus_mt
@@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .opus_mt import OpusMT


def opus_mt(model_name: str):
    return OpusMT(model_name)
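For reference, a minimal usage sketch of the factory function above, under a couple of assumptions: the operator files are importable as a local Python package (the package name used in the import below is hypothetical), and 'opus-mt-en-zh' is one of the model names registered in _configs() in the file that follows. When pulled from the Towhee hub, the operator would normally be loaded through Towhee's pipeline API instead.

# Hypothetical local usage; adjust the import to however the package is laid out.
from opus_mt import opus_mt

translator = opus_mt('opus-mt-en-zh')            # model name from _configs() below
print(translator('Hello, how are you today?'))   # prints a Chinese translation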
@@ -0,0 +1,54 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from towhee.operator.base import NNOperator

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class OpusMT(NNOperator):
    """
    Opus-mt machine translation operator.
    """
    def __init__(self, model_name: str):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Look up the tokenizer and model checkpoints for the requested model name.
        config = self._configs()[model_name]
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model'])
        self.model.to(self.device)

    def __call__(self, data):
        # Tokenize the input text, generate a translation, and decode it back to a string.
        input_ids = self.tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device)
        outputs = self.model.generate(input_ids)
        decoded = self.tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True)
        return decoded

    def _configs(self):
        config = {}
        config['opus-mt-en-zh'] = {}
        config['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh'
        config['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh'

        config['opus-mt-zh-en'] = {}
        config['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en'
        config['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'
        return config
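As a quick sanity check, the class above can also be exercised directly with either of the model names defined in _configs(). This is only a sketch: it assumes the Helsinki-NLP checkpoints can be fetched from the Hugging Face hub, and the sample sentences are illustrative.

# Standalone sanity check (not part of the committed file).
if __name__ == '__main__':
    en_zh = OpusMT('opus-mt-en-zh')
    print(en_zh('Towhee makes it easy to build machine translation pipelines.'))

    zh_en = OpusMT('opus-mt-zh-en')
    print(zh_en('今天天气很好。'))  # input means "The weather is very nice today."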