From fd410c19e28e9b2d36419c9995c4f8a58103cb17 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 31 Jul 2023 17:59:34 +0800 Subject: [PATCH] Replace with chat completion Signed-off-by: Jael Gu --- README.md | 2 +- llama2.py | 26 +++++++++++--------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c3f5972..45ef1be 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Use the default model to continue the conversation from given messages. ```python from towhee import ops -chat = ops.LLM.Llama_2('llama-2-13b-chat', max_tokens=512) +chat = ops.LLM.Llama_2('llama-2-13b-chat', n_ctx=4096, max_tokens=200) message = [ {'system': 'You are a very helpful assistant.'}, diff --git a/llama2.py b/llama2.py index f8b5fcc..ababbde 100644 --- a/llama2.py +++ b/llama2.py @@ -39,36 +39,32 @@ class LlamaCpp(PyOperator): self.model_path = model_name_or_file assert os.path.isfile(self.model_path), f'Invalid model path: {self.model_path}' - self.model = Llama(model_path=self.model_path) + init_kwargs = {} + for k in vars(Llama).keys() & self.kwargs.keys(): + init_kwargs[k] = self.kwargs.pop(k) + self.model = Llama(model_path=self.model_path, **init_kwargs) def __call__(self, messages: List[dict]): - prompt = self.parse_inputs(messages) - resp = self.model(prompt, **self.kwargs) + messages = self.parse_inputs(messages) + resp = self.model.create_chat_completion(messages, **self.kwargs) answer = self.parse_outputs(resp) return answer def parse_inputs(self, messages: List[dict]): assert isinstance(messages, list), \ 'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].' - prompt = '' - question = messages.pop(-1) - assert len(question) == 1 and 'question' in question.keys() - question = question['question'] + new_messages = [] for m in messages: for k, v in m.items(): if k == 'system': - prompt += f'''[INST] <> {v} <>\n''' + new_messages.append({'role': 'system', 'content': v}) elif k == 'question': - prompt += f''' {v} [/INST]\n''' + new_messages.append({'role': 'user', 'content': v}) elif k == 'answer': - prompt += f''' {v} ''' + new_messages.append({'role': 'assistant', 'content': v}) else: raise KeyError(f'Invalid key of message: {k}') - if len(prompt) > 0: - prompt = ' ' + prompt + ' ' + f' [INST] {question} [/INST]' - else: - prompt = question - return prompt + return new_messages def parse_outputs(self, response): return response['choices'][0]['text']