|
|
@ -39,36 +39,32 @@ class LlamaCpp(PyOperator): |
|
|
|
self.model_path = model_name_or_file |
|
|
|
assert os.path.isfile(self.model_path), f'Invalid model path: {self.model_path}' |
|
|
|
|
|
|
|
self.model = Llama(model_path=self.model_path) |
|
|
|
init_kwargs = {} |
|
|
|
for k in vars(Llama).keys() & self.kwargs.keys(): |
|
|
|
init_kwargs[k] = self.kwargs.pop(k) |
|
|
|
self.model = Llama(model_path=self.model_path, **init_kwargs) |
|
|
|
|
|
|
|
def __call__(self, messages: List[dict]): |
|
|
|
prompt = self.parse_inputs(messages) |
|
|
|
resp = self.model(prompt, **self.kwargs) |
|
|
|
messages = self.parse_inputs(messages) |
|
|
|
resp = self.model.create_chat_completion(messages, **self.kwargs) |
|
|
|
answer = self.parse_outputs(resp) |
|
|
|
return answer |
|
|
|
|
|
|
|
def parse_inputs(self, messages: List[dict]): |
|
|
|
assert isinstance(messages, list), \ |
|
|
|
'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].' |
|
|
|
prompt = '' |
|
|
|
question = messages.pop(-1) |
|
|
|
assert len(question) == 1 and 'question' in question.keys() |
|
|
|
question = question['question'] |
|
|
|
new_messages = [] |
|
|
|
for m in messages: |
|
|
|
for k, v in m.items(): |
|
|
|
if k == 'system': |
|
|
|
prompt += f'''[INST] <<SYS>> {v} <</SYS>>\n''' |
|
|
|
new_messages.append({'role': 'system', 'content': v}) |
|
|
|
elif k == 'question': |
|
|
|
prompt += f''' {v} [/INST]\n''' |
|
|
|
new_messages.append({'role': 'user', 'content': v}) |
|
|
|
elif k == 'answer': |
|
|
|
prompt += f''' {v} ''' |
|
|
|
new_messages.append({'role': 'assistant', 'content': v}) |
|
|
|
else: |
|
|
|
raise KeyError(f'Invalid key of message: {k}') |
|
|
|
if len(prompt) > 0: |
|
|
|
prompt = '<s> ' + prompt + ' </s>' + f'<s> [INST] {question} [/INST]' |
|
|
|
else: |
|
|
|
prompt = question |
|
|
|
return prompt |
|
|
|
return new_messages |
|
|
|
|
|
|
|
def parse_outputs(self, response): |
|
|
|
return response['choices'][0]['text'] |
|
|
|