|  | @ -25,6 +25,7 @@ from transformers import AutoTokenizer, AutoConfig, AutoModel | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | from towhee.operator import NNOperator |  |  | from towhee.operator import NNOperator | 
		
	
		
			
				|  |  | from towhee import register |  |  | from towhee import register | 
		
	
		
			
				|  |  |  |  |  | # from towhee.serve.triton.triton_client import TritonClient | 
		
	
		
			
				|  |  | # from towhee.dc2 import accelerate |  |  | # from towhee.dc2 import accelerate | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | import warnings |  |  | import warnings | 
		
	
	
		
			
				|  | @ -37,10 +38,24 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | 
		
	
		
			
				|  |  | t_logging.set_verbosity_error() |  |  | t_logging.set_verbosity_error() | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | def create_model(model_name, checkpoint_path, device): | 
		
	
		
			
				|  |  |  |  |  |     model = AutoModel.from_pretrained(model_name).to(device) | 
		
	
		
			
				|  |  |  |  |  |     if hasattr(model, 'pooler') and model.pooler: | 
		
	
		
			
				|  |  |  |  |  |         model.pooler = None | 
		
	
		
			
				|  |  |  |  |  |     if checkpoint_path: | 
		
	
		
			
				|  |  |  |  |  |         try: | 
		
	
		
			
				|  |  |  |  |  |             state_dict = torch.load(checkpoint_path, map_location=device) | 
		
	
		
			
				|  |  |  |  |  |             model.load_state_dict(state_dict) | 
		
	
		
			
				|  |  |  |  |  |         except Exception: | 
		
	
		
			
				|  |  |  |  |  |             log.error(f'Fail to load weights from {checkpoint_path}') | 
		
	
		
			
				|  |  |  |  |  |     model.eval() | 
		
	
		
			
				|  |  |  |  |  |     return model | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  | # @accelerate |  |  | # @accelerate | 
		
	
		
			
				|  |  | class Model: |  |  | class Model: | 
		
	
		
			
				|  |  |     def __init__(self, model): |  |  |  | 
		
	
		
			
				|  |  |         self.model = model |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     def __init__(self, model_name, checkpoint_path, device): | 
		
	
		
			
				|  |  |  |  |  |         self.model = create_model(model_name, checkpoint_path, device) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     def __call__(self, *args, **kwargs): |  |  |     def __call__(self, *args, **kwargs): | 
		
	
		
			
				|  |  |         outs = self.model(*args, **kwargs, return_dict=True) |  |  |         outs = self.model(*args, **kwargs, return_dict=True) | 
		
	
	
		
			
				|  | @ -79,7 +94,7 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |         if self.model_name: |  |  |         if self.model_name: | 
		
	
		
			
				|  |  |             # model_list = self.supported_model_names() |  |  |             # model_list = self.supported_model_names() | 
		
	
		
			
				|  |  |             # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |  |  |             # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" | 
		
	
		
			
				|  |  |             self.model = Model(self._model) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |             self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device) | 
		
	
		
			
				|  |  |             if tokenizer: |  |  |             if tokenizer: | 
		
	
		
			
				|  |  |                 self.tokenizer = tokenizer |  |  |                 self.tokenizer = tokenizer | 
		
	
		
			
				|  |  |             else: |  |  |             else: | 
		
	
	
		
			
				|  | @ -117,16 +132,10 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     @property |  |  |     @property | 
		
	
		
			
				|  |  |     def _model(self): |  |  |     def _model(self): | 
		
	
		
			
				|  |  |         model = AutoModel.from_pretrained(self.model_name).to(self.device) |  |  |  | 
		
	
		
			
				|  |  |         if hasattr(model, 'pooler') and model.pooler: |  |  |  | 
		
	
		
			
				|  |  |             model.pooler = None |  |  |  | 
		
	
		
			
				|  |  |         if self.checkpoint_path: |  |  |  | 
		
	
		
			
				|  |  |             try: |  |  |  | 
		
	
		
			
				|  |  |                 state_dict = torch.load(self.checkpoint_path, map_location=self.device) |  |  |  | 
		
	
		
			
				|  |  |                 model.load_state_dict(state_dict) |  |  |  | 
		
	
		
			
				|  |  |             except Exception: |  |  |  | 
		
	
		
			
				|  |  |                 log.error(f'Fail to load weights from {self.checkpoint_path}') |  |  |  | 
		
	
		
			
				|  |  |         model.eval() |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         # if isinstance(self.model, TritonClient): | 
		
	
		
			
				|  |  |  |  |  |         #     model = create_model(self.model_name, self.checkpoint_path, self.device) | 
		
	
		
			
				|  |  |  |  |  |         # else: | 
		
	
		
			
				|  |  |  |  |  |         model = self.model.model | 
		
	
		
			
				|  |  |         return model |  |  |         return model | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     @property |  |  |     @property | 
		
	
	
		
			
				|  | 
 |