|
@ -107,8 +107,10 @@ class AutoTransformers(NNOperator): |
|
|
tokenizer: object = None, |
|
|
tokenizer: object = None, |
|
|
pool: str = 'mean', |
|
|
pool: str = 'mean', |
|
|
device: str = None, |
|
|
device: str = None, |
|
|
|
|
|
return_usage: bool = False |
|
|
): |
|
|
): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
|
|
|
self.return_usage = return_usage |
|
|
if pool not in ['mean', 'cls']: |
|
|
if pool not in ['mean', 'cls']: |
|
|
log.warning('Invalid pool %s, using mean pooling instead.', pool) |
|
|
log.warning('Invalid pool %s, using mean pooling instead.', pool) |
|
|
pool = 'mean' |
|
|
pool = 'mean' |
|
@ -152,6 +154,8 @@ class AutoTransformers(NNOperator): |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
|
|
|
|
|
|
|
|
|
|
num_tokens = outs.size(1) |
|
|
if self.pool == 'mean': |
|
|
if self.pool == 'mean': |
|
|
outs = self.mean_pool(outs, inputs) |
|
|
outs = self.mean_pool(outs, inputs) |
|
|
elif self.pool == 'cls': |
|
|
elif self.pool == 'cls': |
|
@ -161,6 +165,9 @@ class AutoTransformers(NNOperator): |
|
|
features = features.squeeze(0) |
|
|
features = features.squeeze(0) |
|
|
else: |
|
|
else: |
|
|
features = list(features) |
|
|
features = list(features) |
|
|
|
|
|
|
|
|
|
|
|
if self.return_usage: |
|
|
|
|
|
return {'data': features, 'token_usage': num_tokens} |
|
|
return features |
|
|
return features |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|