logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

28 lines
881 B

4 years ago
import sys
from pathlib import Path
from typing import NamedTuple, List, Optional

import numpy as np
from towhee.operator import Operator
class EmbeddingConcat(Operator):
    """
    Concatenate several embeddings into a single weighted feature vector.

    Each input embedding is multiplied by its corresponding weight. The
    scaled parts are joined with ``np.hstack`` and the result is flattened
    to a 1-D array.
    """

    # Built once when the class is created rather than on every call, so each
    # returned value is an instance of the same class the return annotation
    # of ``__call__`` names.
    Outputs = NamedTuple('Outputs', [('feature_vector', np.ndarray)])

    def __init__(self) -> None:
        super().__init__()
        # Make modules located next to this file importable. The membership
        # check stops repeated instantiation from piling duplicate entries
        # onto sys.path.
        operator_dir = str(Path(__file__).parent)
        if operator_dir not in sys.path:
            sys.path.append(operator_dir)

    def __call__(
        self,
        input_array_list: List[np.ndarray],
        input_weight_list: Optional[List[float]] = None,
    ) -> Outputs:
        """
        Weight each embedding and concatenate the results.

        Args:
            input_array_list: Embeddings to concatenate, in output order.
                Each element is passed through ``np.asarray``, so array-likes
                such as plain lists are accepted.
            input_weight_list: Per-embedding scale factors, aligned
                index-by-index with ``input_array_list``. ``None`` or an
                empty sequence means a weight of 1.0 for every embedding.

        Returns:
            An ``Outputs`` named tuple whose ``feature_vector`` field holds
            the flattened concatenation of the weighted embeddings.

        Raises:
            ValueError: If a non-empty ``input_weight_list`` does not have the
                same length as ``input_array_list`` (previously the extra
                items were silently dropped by ``zip``). ``np.hstack`` also
                raises ``ValueError`` when ``input_array_list`` is empty.
        """
        # len() covers None-free empties of any sequence type (list, tuple,
        # ndarray); comparing against [] only recognised an empty list.
        if input_weight_list is None or len(input_weight_list) == 0:
            input_weight_list = [1.0] * len(input_array_list)
        elif len(input_weight_list) != len(input_array_list):
            raise ValueError(
                'input_weight_list has %d weights but input_array_list has %d '
                'embeddings' % (len(input_weight_list), len(input_array_list))
            )

        # np.asarray guards against `int * list` (list repetition) and
        # `float * list` (TypeError); it is a no-op for ndarray inputs.
        weighted_parts = [
            w * np.asarray(part)
            for part, w in zip(input_array_list, input_weight_list)
        ]
        # np.hstack joins 1-D inputs along axis 0 and higher-rank inputs
        # along axis 1; flatten() then yields a single 1-D vector.
        concated_feature = np.hstack(weighted_parts).flatten()
        return self.Outputs(concated_feature)