towhee
/
audio-embedding
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
38 lines
1.2 KiB
38 lines
1.2 KiB
from posixpath import basename
|
|
from typing import Optional
|
|
from pydantic import BaseModel
|
|
from towhee import pipe, ops, AutoPipes, AutoConfig
|
|
|
|
|
|
@AutoConfig.register
class AudioEmbeddingConfig(BaseModel):
    """Tunable settings for the ``AudioEmbedding`` pipeline.

    Fields are grouped by the pipeline stage that reads them: the ffmpeg
    audio decoder, the VGGish embedding operator, and the device choice used
    to pick a Triton op config.
    """

    # -- consumed by ops.audio_decode.ffmpeg --
    batch_size: Optional[int] = -1
    sample_rate: Optional[float] = None
    layout: Optional[str] = None

    # -- consumed by ops.audio_embedding.vggish --
    weights_path: Optional[str] = None
    framework: Optional[str] = "pytorch"

    # -- Triton placement: an index >= 0 selects TritonGPUConfig on that
    #    device; a negative index (default -1) selects TritonCPUConfig --
    device: Optional[int] = -1
|
|
|
|
|
|
@AutoPipes.register
def AudioEmbedding(config=None):
    """Build a towhee pipeline that maps an audio path to VGGish embeddings.

    The pipeline takes ``path`` as input, decodes it into ``frame`` with
    ``ops.audio_decode.ffmpeg``, embeds the frames into ``vec`` with
    ``ops.audio_embedding.vggish`` and outputs ``vec``.

    Args:
        config: An ``AudioEmbeddingConfig``. If falsy (e.g. ``None``), a
            default ``AudioEmbeddingConfig()`` is used.

    Returns:
        The object produced by chaining ``pipe.input('path')``, the two
        ``.map`` stages and ``.output('vec')``.
    """
    if not config:
        config = AudioEmbeddingConfig()

    # ``device`` is declared Optional[int] on the config, so it may be None.
    # Treat None like a negative index (CPU) instead of letting ``None >= 0``
    # raise TypeError.
    if config.device is not None and config.device >= 0:
        # NOTE(review): max_batch_size is hard-coded; consider exposing it
        # on the config if callers need to tune it.
        op_config = AutoConfig.TritonGPUConfig(
            device_ids=[config.device], max_batch_size=128
        )
    else:
        op_config = AutoConfig.TritonCPUConfig()

    decoder = ops.audio_decode.ffmpeg(
        batch_size=config.batch_size,
        sample_rate=config.sample_rate,
        layout=config.layout,
    )
    embedder = ops.audio_embedding.vggish(
        weights_path=config.weights_path,
        framework=config.framework,
    )

    # Only the embedding stage receives the Triton op config.
    return (
        pipe.input('path')
        .map('path', 'frame', decoder)
        .map('frame', 'vec', embedder, config=op_config)
        .output('vec')
    )
|