towhee
/
image-embedding
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
38 lines
1.2 KiB
38 lines
1.2 KiB
2 years ago
|
from towhee import pipe, ops, AutoPipes, AutoConfig
|
||
|
|
||
|
|
||
|
@AutoConfig.register
class ImageEmbeddingConfig:
    """Default settings read by the ``ImageEmbedding`` pipeline builder.

    Every field is a plain instance attribute initialised to a default value.
    """

    def __init__(self):
        # Image decoding step: forwarded as ``mode`` to ``ops.image_decode``.
        self.mode: str = 'BGR'

        # Embedding step: forwarded as keyword arguments to
        # ``ops.image_embedding.timm``.
        self.model_name: str = 'resnet50'
        self.num_classes: int = 1000
        self.skip_preprocess: bool = False

        # Triton serving: a value >= 0 is used as the single GPU device id;
        # any other value selects the CPU configuration.
        self.device: int = -1
|
||
|
|
||
|
|
||
|
@AutoPipes.register
def ImageEmbedding(config=None):
    """Build the image-embedding pipeline: path -> decoded image -> embedding.

    Args:
        config: Object exposing ``mode``, ``model_name``, ``num_classes``,
            ``skip_preprocess`` and ``device`` attributes. Any falsy value
            (including the default ``None``) is replaced by a fresh
            ``ImageEmbeddingConfig()``.

    Returns:
        The pipeline produced by ``.output('embedding')``, taking a single
        ``'path'`` input.
    """
    settings = config or ImageEmbeddingConfig()

    # A non-negative device id selects the Triton GPU configuration pinned to
    # that device; otherwise the Triton CPU configuration is used.
    op_config = (
        AutoConfig.TritonGPUConfig(device_ids=[settings.device], max_batch_size=128)
        if settings.device >= 0
        else AutoConfig.TritonCPUConfig()
    )

    # Steps are created in the same order as the original chained expression.
    pipeline = pipe.input('path')

    decode_op = ops.image_decode(mode=settings.mode)
    pipeline = pipeline.map('path', 'img', decode_op)

    embed_op = ops.image_embedding.timm(
        model_name=settings.model_name,
        num_classes=settings.num_classes,
        skip_preprocess=settings.skip_preprocess,
    )
    # Only the embedding step receives the Triton configuration.
    pipeline = pipeline.map('img', 'embedding', embed_op, config=op_config)

    return pipeline.output('embedding')
|