towhee
/
embedding-concat
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
27 lines
881 B
27 lines
881 B
import sys
|
|
from typing import NamedTuple, List
|
|
from pathlib import Path
|
|
import numpy as np
|
|
|
|
from towhee.operator import Operator
|
|
|
|
|
|
class EmbeddingConcat(Operator):
    """
    Concatenate several embeddings into a single flat feature vector.

    Each input embedding is scaled by its corresponding weight, the scaled
    parts are joined with ``np.hstack``, and the result is flattened to 1-D.
    """

    # Built once so every call returns instances of one and the same type,
    # instead of creating a fresh NamedTuple class per invocation.
    _Outputs = NamedTuple('Outputs', [('feature_vector', np.ndarray)])

    def __init__(self) -> None:
        super().__init__()
        # Make modules living next to this file importable. Guarded so that
        # constructing the operator repeatedly does not keep growing sys.path
        # with duplicate entries.
        op_dir = str(Path(__file__).parent)
        if op_dir not in sys.path:
            sys.path.append(op_dir)

    def __call__(self, input_array_list: List[np.ndarray], input_weight_list: List[float]) -> _Outputs:
        """
        Weight and concatenate the given embeddings.

        Args:
            input_array_list: Embeddings to concatenate, in output order.
                Array-likes (e.g. plain lists) are converted with
                ``np.asarray`` before weighting.
            input_weight_list: One scalar weight per embedding. ``None`` or an
                empty sequence means every embedding gets weight ``1.0``.

        Returns:
            An ``Outputs`` NamedTuple whose ``feature_vector`` field holds the
            flattened, weighted concatenation.

        Raises:
            ValueError: If a non-empty weight list does not have exactly one
                weight per embedding, or (raised by ``np.hstack``) if there
                are no embeddings at all.
        """
        # len() instead of truthiness / ``== []`` so that a NumPy array of
        # weights is accepted as well as a list.
        if input_weight_list is None or len(input_weight_list) == 0:
            input_weight_list = [1.0] * len(input_array_list)
        elif len(input_weight_list) != len(input_array_list):
            # zip() below would otherwise silently drop the unmatched
            # embeddings, producing a truncated feature vector.
            raise ValueError(
                f'Expected {len(input_array_list)} weights (one per embedding), '
                f'got {len(input_weight_list)}.'
            )

        # np.asarray makes ``w * part`` numeric scaling; on a plain list it
        # would raise (float w) or repeat the list (int w).
        weighted_parts = [
            w * np.asarray(part)
            for part, w in zip(input_array_list, input_weight_list)
        ]
        concated_feature = np.hstack(weighted_parts).flatten()
        return self._Outputs(concated_feature)
|
|
|