from typing import Dict, List, Optional

from towhee.operator import PyOperator


class TemplatePrompt(PyOperator):
    """Build a chat-style message list from a string template.

    The template is filled via ``str.format`` with the positional call
    arguments, bound in order to ``keys``. An optional system message is
    emitted first and optional chat history is emitted before the newly
    rendered question.

    Args:
        temp: Template string containing ``{name}`` placeholders.
        keys: Placeholder names, in the same order the corresponding values
            are passed to ``__call__``.
        sys_msg: Optional system message; when falsy, no system entry is
            emitted.
    """

    def __init__(self, temp: str, keys: List[str], sys_msg: Optional[str] = None):
        super().__init__()
        self._template = temp
        self._keys = keys
        self._sys_msg = sys_msg

    def __call__(self, *args) -> List[Dict[str, str]]:
        """Render the template and assemble the message list.

        Args:
            *args: One value per entry in ``keys`` (same order), optionally
                followed by one extra argument holding the chat history: an
                iterable of sequences whose first element is the question and
                second element is the answer. A falsy history is ignored.

        Returns:
            ``{'system': sys_msg}`` (only when a system message was given),
            then one ``{'question': q, 'answer': a}`` dict per history item,
            then ``{'question': <rendered template>}``.

        Raises:
            ValueError: If the number of positional arguments is neither
                ``len(keys)`` nor ``len(keys) + 1``.
        """
        num_keys = len(self._keys)
        if len(args) == num_keys:
            history = None
        elif len(args) == num_keys + 1:
            history = args[-1]
        else:
            # Without this check a short call would misread the last template
            # value as history, and extra arguments would be silently dropped.
            raise ValueError(
                'TemplatePrompt expects {} template value(s) optionally followed '
                'by history, got {} argument(s)'.format(num_keys, len(args))
            )

        messages = [{'system': self._sys_msg}] if self._sys_msg else []

        if history:
            # Index positionally (not tuple-unpack) so history items longer
            # than two elements keep working.
            messages.extend(
                {'question': item[0], 'answer': item[1]} for item in history
            )

        # zip stops at num_keys, so the trailing history argument (if any) is
        # never passed to the template.
        prompt_str = self._template.format(**dict(zip(self._keys, args)))
        messages.append({'question': prompt_str})
        return messages