diff --git a/README.md b/README.md index 1abe899..81792b8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,2 @@ -# opus-mt +# opus_mt diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..89f28f3 --- /dev/null +++ b/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .opus_mt import OpusMT + +def opus_mt(model_name: str): + return OpusMT(model_name) diff --git a/opus_mt.py b/opus_mt.py new file mode 100644 index 0000000..0ff5566 --- /dev/null +++ b/opus_mt.py @@ -0,0 +1,54 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + +import torch +from torchvision import transforms + +from towhee.types.image_utils import to_pil +from towhee.operator.base import NNOperator, OperatorFlag +from towhee.types.arg import arg, to_image_color +from towhee import register + +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +class OpusMT() + """ + Opus-mt machine translation + """ + def __init__(self, model_name: str): + super().__init__() + self.device = "cuda" if torch.cuda.is_available() else "cpu" + config = self.configs[model_name] + self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']) + self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']) + self.model.to(self.device) + + def __call__(self, data): + input_ids = tokenizer(data, return_tensors='pt', padding=True)['input_ids'].to(self.device) + outputs = model.generate(input_ids) + decoded = tokenizer.decode(outputs[0].detach().cpu(), skip_special_tokens=True) + return decoded +~ + def _configs(self): + config = {} + config['opus-mt-en-zh'] = {} + config['opus-mt-en-zh']['tokenizer'] = 'Helsinki-NLP/opus-mt-en-zh' + config['opus-mt-en-zh']['model'] = 'Helsinki-NLP/opus-mt-en-zh' + + config['opus-mt-zh-en'] = {} + config['opus-mt-zh-en']['tokenizer'] = 'Helsinki-NLP/opus-mt-zh-en' + config['opus-mt-zh-en']['model'] = 'Helsinki-NLP/opus-mt-zh-en'