diff --git a/README.md b/README.md index fc9c035..52a7025 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ p = ( history=[('Who won the world series in 2020?', 'The Los Angeles Dodgers won the World Series in 2020.')] question = 'Where was it played?' -answer = p(question, [], history) +answer = p(question, [], history).get()[0] ```
diff --git a/hf_dolly.py b/hf_dolly.py index 0558779..b272c1e 100644 --- a/hf_dolly.py +++ b/hf_dolly.py @@ -35,7 +35,7 @@ class HuggingfaceDolly(PyOperator): def __call__(self, messages: List[dict]): prompt = self.parse_inputs(messages) ans = self.pipeline(prompt) - return ans + return ans[0]['generated_text'] def parse_inputs(self, messages: List[dict]): assert isinstance(messages, list), \ @@ -43,9 +43,8 @@ class HuggingfaceDolly(PyOperator): prompt = messages[-1]['question'] history = '' for m in messages[:-1]: - for k, v in m.items(): - line = k + ': ' + v + '\n' - history += line + for _, v in m.items(): + history += v + '\n' return prompt + '\n' + history