# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import os
import requests
import torch
import shutil
from pathlib import Path
from typing import Union
from collections import OrderedDict

from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM

from towhee.operator import NNOperator
from towhee import register
# from towhee.dc2 import accelerate

import warnings
import logging
from transformers import logging as t_logging

from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer

log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

log.setLevel(logging.ERROR)
t_logging.set_verbosity_error()


def create_model(model_name, checkpoint_path, device):
    model = AutoModel.from_pretrained(model_name).to(device)
    if hasattr(model, 'pooler') and model.pooler:
        model.pooler = None
    if checkpoint_path:
        try:
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        except Exception:
            log.error(f'Fail to load weights from {checkpoint_path}')
    model.eval()
    return model


# @accelerate
class Model:
    def __init__(self, model_name, checkpoint_path, device):
        self.device = device
        self.model = create_model(model_name, checkpoint_path, device)

    def __call__(self, *args, **kwargs):
        new_args = []
        for x in args:
            new_args.append(x.to(self.device))
        new_kwargs = {}
        for k, v in kwargs.items():
            new_kwargs[k] = v.to(self.device)
        outs = self.model(*new_args, **new_kwargs, return_dict=True)
        return outs['last_hidden_state']


@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
    """
    NLP embedding operator that uses the pretrained transformers model gathered by huggingface.
    Args:
        model_name (`str`):
            The model name to load a pretrained model from transformers.
        checkpoint_path (`str`):
            The local checkpoint path.
        tokenizer (`object`):
            The tokenizer to tokenize input text as model inputs.
    """

    def __init__(self,
                 model_name: str = None,
                 checkpoint_path: str = None,
                 tokenizer: object = None,
                 device: str = None,
                 ):
        super().__init__()
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if model_name in s_list:
            self.model_name = 'sentence-transformers/' + model_name
        else:
            self.model_name = model_name
        self.checkpoint_path = checkpoint_path

        if self.model_name:
            # model_list = self.supported_model_names()
            # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
            self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = '[PAD]'
        else:
            log.warning('The operator is initialized without specified model.')
            pass

    def __call__(self, data: Union[str, list]) -> numpy.ndarray:
        if isinstance(data, str):
            txt = [data]
        else:
            txt = data
        try:
            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
        except Exception as e:
            log.error(f'Fail to tokenize inputs: {e}')
            raise e
        try:
            outs = self.model(**inputs).to('cpu')
        except Exception as e:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise e
        outs = self.post_proc(outs, inputs)
        features = outs.detach().numpy()
        if isinstance(data, str):
            features = features.squeeze(0)
        else:
            features = list(features)
        return features

    @property
    def _model(self):
        model = self.model.model
        return model

    @property
    def model_config(self):
        configs = AutoConfig.from_pretrained(self.model_name)
        return configs

    @property
    def onnx_config(self):
        from transformers.onnx.features import FeaturesManager
        try:
            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            old_config = model_onnx_config(self.model_config)
            onnx_config = {
                'inputs': dict(old_config.inputs),
                'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
            }
        except Exception:
            input_dict = {}
            for k in self.tokenizer.model_input_names:
                input_dict[k] = {0: 'batch_size', 1: 'sequence_length'}
            onnx_config = {
                'inputs': input_dict,
                'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}
            }
        return onnx_config

    def post_proc(self, token_embeddings, inputs):
        token_embeddings = token_embeddings
        attention_mask = inputs['attention_mask']
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sentence_embs = torch.sum(
            token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sentence_embs

    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        if output_file == 'default':
            output_file = str(Path(__file__).parent)
            output_file = os.path.join(output_file, 'saved', model_type)
            os.makedirs(output_file, exist_ok=True)
            name = self.model_name.replace('/', '-')
            output_file = os.path.join(output_file, name)
            if model_type in ['pytorch', 'torchscript']:
                output_file = output_file + '.pt'
            elif model_type == 'onnx':
                output_file = output_file + '.onnx'
            else:
                raise AttributeError('Unsupported model_type.')

        dummy_input = 'test sentence'
        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt')
        if model_type == 'pytorch':
            torch.save(self._model, output_file)
        elif model_type == 'torchscript':
            inputs = list(inputs.values())
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    jit_model = torch.jit.trace(self._model, inputs, strict=False)
                torch.jit.save(jit_model, output_file)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.')
        elif model_type == 'onnx':
            dynamic_axes = {}
            for k, v in self.onnx_config['inputs'].items():
                dynamic_axes[k] = v
            for k, v in self.onnx_config['outputs'].items():
                dynamic_axes[k] = v
            if hasattr(self._model.config, 'use_cache'):
                self._model.config.use_cache = False
            torch.onnx.export(
                self._model.to('cpu'),
                tuple(inputs.values()),
                output_file,
                input_names=list(self.onnx_config['inputs'].keys()),
                output_names=list(self.onnx_config['outputs'].keys()),
                dynamic_axes=dynamic_axes,
                opset_version=torch.onnx.constant_folding_opset_versions[-1] if hasattr(
                    torch.onnx, 'constant_folding_opset_versions') else 14,
                do_constant_folding=True,
            )
        # todo: elif format == 'tensorrt':
        else:
            log.error(f'Unsupported format "{format}".')
        return Path(output_file).resolve()

    @property
    def supported_formats(self):
        return ['onnx']

    @staticmethod
    def supported_model_names(format: str = None):
        add_models = [
            'bert-base-uncased',
            'bert-large-uncased',
            'bert-large-uncased-whole-word-masking',
            'distilbert-base-uncased',
            'facebook/bart-large',
            'gpt2-xl',
            'microsoft/deberta-xlarge',
            'microsoft/deberta-xlarge-mnli',
        ]
        full_list = s_list + add_models
        full_list.sort()
        if format is None:
            model_list = full_list
        elif format == 'pytorch':
            to_remove = []
            assert set(to_remove).issubset(set(full_list))
            model_list = list(set(full_list) - set(to_remove))
        elif format == 'torchscript':
            to_remove = []
            assert set(to_remove).issubset(set(full_list))
            model_list = list(set(full_list) - set(to_remove))
        elif format == 'onnx':
            to_remove = ['gpt2-xl']
            assert set(to_remove).issubset(set(full_list))
            model_list = list(set(full_list) - set(to_remove))
        # todo: elif format == 'tensorrt':
        else:
            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
        return model_list

    def train(self, training_config=None,
              train_dataset=None,
              eval_dataset=None,
              resume_checkpoint_path=None, **kwargs):
        task = kwargs.pop('task', None)
        data_args = kwargs.pop('data_args', None)
        training_args = kwargs.pop('training_args', None)
        prepare_model_weights_f = kwargs.pop('prepare_model_weights_f', None)

        if task == 'mlm' or task is None:
            model_with_head = AutoModelForMaskedLM.from_pretrained(self.model_name)
            if prepare_model_weights_f is not None:
                model_with_head = prepare_model_weights_f(self._model, model_with_head, **kwargs)

            train_mlm_with_hf_trainer(
                model_with_head,
                self.tokenizer,
                data_args,
                training_args,
                **kwargs
            )
        elif task == 'clm':
            model_with_head = AutoModelForCausalLM.from_pretrained(self.model_name)
            if prepare_model_weights_f is not None:
                model_with_head = prepare_model_weights_f(self._model, model_with_head, **kwargs)

            train_clm_with_hf_trainer(
                model_with_head,
                self.tokenizer,
                data_args,
                training_args,
                **kwargs
            )

s_list = [
    'paraphrase-MiniLM-L3-v2',
    'paraphrase-MiniLM-L6-v2',
    'paraphrase-MiniLM-L12-v2',
    'paraphrase-distilroberta-base-v2',
    'paraphrase-TinyBERT-L6-v2',
    'paraphrase-mpnet-base-v2',
    'paraphrase-albert-small-v2',
    'paraphrase-multilingual-mpnet-base-v2',
    'paraphrase-multilingual-MiniLM-L12-v2',
    'distiluse-base-multilingual-cased-v1',
    'distiluse-base-multilingual-cased-v2',
    'all-distilroberta-v1',
    'all-MiniLM-L6-v1',
    'all-MiniLM-L6-v2',
    'all-MiniLM-L12-v1',
    'all-MiniLM-L12-v2',
    'all-mpnet-base-v1',
    'all-mpnet-base-v2',
    'all-roberta-large-v1',
    'multi-qa-MiniLM-L6-dot-v1',
    'multi-qa-MiniLM-L6-cos-v1',
    'multi-qa-distilbert-dot-v1',
    'multi-qa-distilbert-cos-v1',
    'multi-qa-mpnet-base-dot-v1',
    'multi-qa-mpnet-base-cos-v1',
    'msmarco-distilbert-dot-v5',
    'msmarco-bert-base-dot-v5',
    'msmarco-distilbert-base-tas-b',
    'bert-base-nli-mean-tokens',
    'msmarco-distilbert-base-v4'
]