Browse Source
        
      
      update
      
        Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 1 changed files with 
13 additions and 
6 deletions
			 
			
		 
		
			
				- 
					
					
					 
					template_prompt.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -1,18 +1,25 @@ | 
			
		
	
		
			
				
					|  |  |  | from typing import List, Tuple, Dict, Optional | 
			
		
	
		
			
				
					|  |  |  | from typing import List | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from towhee.operator import PyOperator | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | class TemplatePrompt(PyOperator): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, temp: str): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, temp: str, keys: List[str]): | 
			
		
	
		
			
				
					|  |  |  |         super().__init__() | 
			
		
	
		
			
				
					|  |  |  |         self._template = temp | 
			
		
	
		
			
				
					|  |  |  |         self._keys = keys | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def __call__(self, **kwargs) -> List[Dict[str, str]]: | 
			
		
	
		
			
				
					|  |  |  |         history = kwargs.get('history', []) | 
			
		
	
		
			
				
					|  |  |  |         prompt_str = self._template.format(**kwargs) | 
			
		
	
		
			
				
					|  |  |  |     def __call__(self, *args) -> List[Dict[str, str]]: | 
			
		
	
		
			
				
					|  |  |  |         if len(self._keys) == len(args): | 
			
		
	
		
			
				
					|  |  |  |             history = [] | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             history = args[-1] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         kws = {(item[0], item[1]) for item in zip(self._keys, args)} | 
			
		
	
		
			
				
					|  |  |  |         prompt_str = self._template.format(**kws) | 
			
		
	
		
			
				
					|  |  |  |         ret = [{'question': prompt_str}] | 
			
		
	
		
			
				
					|  |  |  |         if not isinstance(history, list): | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         if not history: | 
			
		
	
		
			
				
					|  |  |  |             return ret | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             history_data = [] | 
			
		
	
	
		
			
				
					|  |  | 
 |