|
|
|
from typing import Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from towhee import pipe, ops, AutoPipes, AutoConfig
|
|
|
|
|
|
|
|
|
|
|
|
@AutoConfig.register
class ImageEmbeddingConfig(BaseModel):
    """Tunable settings for the ``ImageEmbedding`` pipeline.

    Every field carries a default, so ``ImageEmbeddingConfig()`` is a
    complete configuration on its own.
    """

    # --- image decoding: forwarded to ops.image_decode(mode=...) ---
    mode: Optional[str] = "BGR"

    # --- embedding model: forwarded to ops.image_embedding.timm(...) ---
    model_name: Optional[str] = "resnet50"
    num_classes: Optional[int] = 1000
    skip_preprocess: Optional[bool] = False

    # --- Triton op config selection ---
    # ImageEmbedding treats a value >= 0 as a GPU device id (TritonGPUConfig)
    # and a negative value as "use TritonCPUConfig".
    device: Optional[int] = -1
|
|
|
|
|
|
|
|
|
|
|
|
@AutoPipes.register
def ImageEmbedding(config=None):
    """Build a towhee pipeline that maps an image path to an embedding.

    The pipeline decodes the image found at ``path`` with
    ``ops.image_decode`` and embeds it with a timm model through
    ``ops.image_embedding.timm``.

    Args:
        config: An ``ImageEmbeddingConfig``. ``None`` (the default) means
            ``ImageEmbeddingConfig()`` with all default settings.

    Returns:
        A towhee pipeline whose input is ``path`` and whose output is
        ``embedding``.
    """
    if config is None:
        config = ImageEmbeddingConfig()

    # ``device`` is declared Optional on the config, so a None value must not
    # reach the ``>= 0`` comparison (it would raise TypeError). None is
    # treated like a negative id: fall back to the CPU Triton config.
    device = config.device
    if device is not None and device >= 0:
        # max_batch_size is a fixed ceiling for the GPU Triton config.
        op_config = AutoConfig.TritonGPUConfig(device_ids=[device], max_batch_size=128)
    else:
        op_config = AutoConfig.TritonCPUConfig()

    # Only the embedding node receives op_config; decoding runs with the
    # node's default config.
    return (
        pipe.input('path')
        .map('path', 'img', ops.image_decode(mode=config.mode))
        .map('img', 'embedding', ops.image_embedding.timm(model_name=config.model_name,
                                                          num_classes=config.num_classes,
                                                          skip_preprocess=config.skip_preprocess),
             config=op_config)
        .output('embedding')
    )
|