# Copyright 2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

from towhee.operator.base import PyOperator

import erniebot


class ErnieChat(PyOperator):
    '''Wrapper of Ernie Bot SDK'''

    def __init__(self,
                 model_name: str = 'ernie-bot-turbo',
                 eb_api_type: str = None,
                 eb_access_token: str = None,
                 **kwargs
                 ):
        erniebot.api_type = eb_api_type or os.getenv('EB_API_TYPE')
        erniebot.access_token = eb_access_token or os.getenv('EB_ACCESS_TOKEN')
        self._model = model_name
        self.stream = kwargs.pop('stream') if 'stream' in kwargs else False
        self.kwargs = kwargs

    def __call__(self, messages: List[dict]):
        messages = self.parse_inputs(messages)
        response = erniebot.ChatCompletion.create(
            model=self._model,
            messages=messages,
            stream=self.stream,
            **self.kwargs
        )
        if self.stream:
            return self.stream_output(response)
        else:
            answer = response.result
            return answer

    def parse_inputs(self, messages: List[dict]):
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        new_messages = []
        for m in messages:
            if ('role' and 'content' in m) and (m['role'] in ['system', 'assistant', 'user']):
                new_messages.append(m)
            else:
                for k, v in m.items():
                    if k == 'question':
                        new_m = {'role': 'user', 'content': v}
                    elif k == 'answer':
                        new_m = {'role': 'assistant', 'content': v}
                    elif k == 'system':
                        new_m = {'role': 'system', 'content': v}
                    else:
                        raise KeyError(
                            'Invalid message key: only accept key value from ["system", "question", "answer"].')
                    new_messages.append(new_m)
        return new_messages

    def stream_output(self, response):
        for resp in response:
            yield resp.result

    @staticmethod
    def supported_model_names():
        model_list = [
            'ernie-bot',
            'ernie-bot-turbo'
        ]
        model_list.sort()
        return model_list